ops.py 98 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import math
  4. import operator
  5. from typing import * # noqa: F403
  6. from typing import Optional
  7. import torch
  8. import torch.nn.functional as F
  9. from torch.fx.operator_schemas import normalize_function
  10. from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
  11. from .nested_tensor import NestedTensor
# Nothing is exported through `from <module> import *`.
__all__: list[Any] = []

# Dispatch table for jagged-layout NestedTensor ops. Keys are the ops passed to
# register_func(); values are the wrappers it installs, which validate the call
# schema and then forward to the registered implementation.
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
  14. def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
  15. from torch._prims_common import canonicalize_dims
  16. if isinstance(dim, (tuple, list)):
  17. output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
  18. # ensure no duplicates, which can result from both batch and ragged mapping to 0
  19. return type(output)(dict.fromkeys(output))
  20. if canonicalize:
  21. dim = canonicalize_dims(ndim, dim)
  22. assert dim >= 0 and dim < ndim
  23. # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
  24. # For other dims, subtract 1 to convert to inner space.
  25. return ragged_dim - 1 if dim == 0 else dim - 1
  26. def _wrap_jagged_dim(
  27. ndim,
  28. dim,
  29. ragged_dim,
  30. op_name,
  31. convert_to_inner_dim=True,
  32. allow_ragged_dim=False,
  33. allow_batch_dim=False,
  34. ):
  35. from torch._prims_common import canonicalize_dims
  36. wrapped = canonicalize_dims(ndim, dim)
  37. if wrapped == ragged_dim and not allow_ragged_dim:
  38. raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
  39. elif wrapped == 0 and not allow_batch_dim:
  40. raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
  41. ret = (
  42. _outer_to_inner_dim(ndim, wrapped, ragged_dim)
  43. if convert_to_inner_dim
  44. else wrapped
  45. )
  46. if allow_batch_dim:
  47. # Need to disambiguate whether we're operating on the batch dim or not.
  48. # Operating on dim=1 -> dim=0 after the inner dim conversion.
  49. operating_on_batch = wrapped == 0
  50. return (ret, operating_on_batch)
  51. return ret
  52. def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
  53. """
  54. For NestedTensor operators,
  55. wraps dimensions to non-negative values,
  56. and returns metadata related to reduction dimension(s).
  57. """
  58. from torch._prims_common import canonicalize_dims
  59. assert isinstance(dims, (tuple, list)), (
  60. f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
  61. )
  62. wrapped_dims = [
  63. canonicalize_dims(ndim, d) for d in dims
  64. ] # convert all indices to non-negative values
  65. operate_on_batch = 0 in wrapped_dims
  66. operate_on_ragged = ragged_idx in wrapped_dims
  67. operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
  68. # ensure no duplicates, which can result from both batch and ragged mapping to 0
  69. outer_to_inner_dim = tuple(
  70. dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
  71. )
  72. return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
  73. def check_schema(schema_str: str, func, *args, **kwargs) -> None:
  74. named_arg_types = schema_str.split(", ")
  75. num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
  76. min_args = len(named_arg_types) - num_optional_args
  77. # special case: ellipses allows for any number of unchecked args at the end
  78. if named_arg_types[-1] == "...":
  79. named_arg_types = named_arg_types[:-1]
  80. else:
  81. if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
  82. raise ValueError(
  83. f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
  84. f"arguments and at most {len(named_arg_types)} arguments, but got: "
  85. f"{len(args)} arguments"
  86. )
  87. arg_type_check_fns = {
  88. "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
  89. "jt": lambda x: isinstance(x, NestedTensor)
  90. and x._lengths is None
  91. and x._ragged_idx == 1, # ops with "jt" require contiguous JT only
  92. "jt_all": lambda x: isinstance(
  93. x, NestedTensor
  94. ), # ops with "jt_all" can accept all kinds of JT
  95. "any": lambda x: True,
  96. }
  97. for i, named_arg_type in enumerate(named_arg_types):
  98. name, arg_type = named_arg_type.split(": ")
  99. is_optional = arg_type.endswith("?")
  100. normalized_arg_type = arg_type[:-1] if is_optional else arg_type
  101. if normalized_arg_type not in arg_type_check_fns.keys():
  102. raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
  103. if i >= len(args):
  104. if not is_optional:
  105. raise ValueError(
  106. f"NestedTensor {func.__name__}({schema_str}) "
  107. f"missing required argument: {name}"
  108. )
  109. continue
  110. _check_fn = arg_type_check_fns[normalized_arg_type]
  111. def check_fn(x, is_optional=is_optional):
  112. if is_optional:
  113. return x is None or _check_fn(x)
  114. else:
  115. return _check_fn(x)
  116. if not check_fn(args[i]):
  117. type_to_desc = {
  118. "t": "tensor",
  119. "t?": "optional tensor",
  120. "jt": "contiguous jagged layout NestedTensor",
  121. "jt_all": "jagged layout NestedTensor",
  122. "any": "<any type>",
  123. }
  124. raise ValueError(
  125. f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
  126. f"{type_to_desc[arg_type]}"
  127. )
  128. def check_ragged_dim_same(
  129. func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
  130. ) -> None:
  131. # Calling into .shape here
  132. if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
  133. raise RuntimeError(
  134. f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
  135. "same exact offsets tensor."
  136. )
  137. # returns True if the raggedness-relevant portions of the NT shape
  138. # match those of the specified size
  139. def raggedness_matches(nt, size):
  140. end = nt._ragged_idx + 1
  141. nt_ragged = nt._size[:end]
  142. size_ragged = size[:end]
  143. return len(nt_ragged) == len(size_ragged) and (
  144. all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
  145. )
  146. def squeeze_leading_ones(t):
  147. # Note: [ Squeezing leading ones ]
  148. #
  149. # Squeeze leading ones from t.
  150. #
  151. # We want:
  152. # (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
  153. # (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported)
  154. #
  155. # 1) Squeeze extra ones and grab values from NT
  156. # (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?)
  157. # 2) Do dense broadcasting:
  158. # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
  159. # 3) Construct nested tensor
  160. # (sum(*), ?, ?) -> (B, j0, ?, ?)
  161. #
  162. # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
  163. # at step (4) and we would need to update this function to record how
  164. # many ones we unsqueezed.
  165. while t.dim() > 0 and t.shape[0] == 1:
  166. t = t.squeeze(0)
  167. return t
  168. def register_func(tables, aten_ops, schema_str):
  169. if not isinstance(aten_ops, list):
  170. aten_ops = [aten_ops]
  171. if not isinstance(tables, list):
  172. tables = [tables]
  173. def wrapper(func):
  174. for aten_op in aten_ops:
  175. def get_inner(aten_op):
  176. def inner(*args, **kwargs):
  177. check_schema(schema_str, func, *args, **kwargs)
  178. return func(aten_op, *args, **kwargs)
  179. return inner
  180. for table in tables:
  181. table[aten_op] = get_inner(aten_op)
  182. return func
  183. return wrapper
# register_func() pre-bound to JAGGED_OPS_TABLE; used as
# `@register_jagged_func(aten_op_or_list, schema_str)`.
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
  185. def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
  186. dispatch_func = JAGGED_OPS_TABLE.get(func, None)
  187. if dispatch_func is not None:
  188. return dispatch_func
  189. # Handle pointwise fallbacks
  190. if torch.Tag.pointwise in func.tags:
  191. from torch.fx.experimental.symbolic_shapes import is_nested_int
  192. # No pointwise ops legitimately accept nested int inputs. Without this check,
  193. # they will be incorrectly interpreted as tensors.
  194. # See https://github.com/pytorch/pytorch/issues/138496
  195. for arg in args:
  196. if is_nested_int(arg):
  197. raise RuntimeError(
  198. f"NestedTensor {func.__name__}: invalid argument {arg}"
  199. )
  200. # Assume there aren't additional tensors that aren't the "unary/binary" args
  201. num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
  202. if num_tensor_args == 1:
  203. # Build up the check schema string. The first tensor arg is assumed to be
  204. # an NJT and other args are sent through as-is.
  205. schema_parts = []
  206. for arg in func._schema.arguments:
  207. if isinstance(arg.type, torch.TensorType):
  208. schema_parts.append(f"{arg.name}: jt_all")
  209. break
  210. else:
  211. schema_parts.append(f"{arg.name}: any")
  212. schema_parts.append("...")
  213. check_schema_str = ", ".join(schema_parts)
  214. check_schema(check_schema_str, func, *args, **kwargs)
  215. return functools.partial(jagged_unary_pointwise, func)
  216. elif num_tensor_args == 2:
  217. check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
  218. return functools.partial(jagged_binary_pointwise, func)
  219. return None
  220. def extract_kwargs(arg):
  221. kwargs = {
  222. "offsets": arg.offsets(),
  223. "lengths": arg.lengths(),
  224. "_metadata_cache": arg._metadata_cache,
  225. "_ragged_idx": arg._ragged_idx,
  226. }
  227. return kwargs
  228. def jagged_unary_pointwise(func, *args, **kwargs):
  229. # assume if we get here that there is a single NJT input in the args
  230. njt = next(arg for arg in args if isinstance(arg, NestedTensor))
  231. return NestedTensor(
  232. func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
  233. **extract_kwargs(njt),
  234. )
def jagged_binary_pointwise(func, *args, **kwargs):
    """Apply a binary pointwise op where at least one operand is a NestedTensor.

    The first two positional args are the operands; any further args and all
    kwargs are forwarded to ``func`` untouched. Cases are tried in order:

    1. Both operands are NestedTensors: their raggedness must match, and
       ``func`` runs on the two ``_values`` tensors; the result takes the
       metadata of the first operand.
    2. One NestedTensor and one dense tensor that, after squeezing leading
       ones, has at least two fewer dims than the NestedTensor: ``func`` runs
       on ``_values`` and the squeezed dense tensor, relying on regular dense
       broadcasting.
    3. Both operands have the same number of dims and the same size at dim 0:
       the dense operand is expanded along the ragged dim, converted to a
       NestedTensor via ``nested_from_padded`` and ``func`` is re-invoked on
       two NestedTensors.

    Raises:
        RuntimeError: the operand shapes cannot be combined.
        NotImplementedError: the dense operand has more dims than the
            NestedTensor.
    """
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )
    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
    # either a is NT or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    # === Handle broadcasting across the batch / ragged dims ===

    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        # Keep the operand order of the original call (matters for
        # non-commutative ops).
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Harder case: do manual broadcasting when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
        # be (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        from .nested_tensor import nested_from_padded

        # handle broadcasting via padded dense -> jagged conversion
        min_seqlen = nt._maybe_min_seqlen
        max_seqlen = nt._maybe_max_seqlen
        padded_max_S = max_seqlen
        # size of the packed (ragged) dim of the values tensor
        total_L = nt._values.shape[nt._ragged_idx - 1]
        if padded_max_S is None:
            # use upper bound on max seqlen if it's not present
            padded_max_S = total_L

        # convert dense tensor -> jagged
        # (only the ragged dim is expanded; every other dim keeps its size)
        t = t.expand(
            [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)]
        )
        t_as_nt = nested_from_padded(
            t,
            offsets=nt._offsets,
            ragged_idx=nt._ragged_idx,
            sum_S=total_L,
            min_seqlen=min_seqlen,
            max_seqlen=max_seqlen,
        )

        # function call with two NJTs
        lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt)
        return func(lhs, rhs, *args[2:], **kwargs)

    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
    # that ragged dim is wrt left-most batch dim
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
def jagged_torch_function(func, *args, **kwargs):
    """Handle the torch-level functions that need NestedTensor-specific treatment.

    Handled functions:

    * ``scaled_dot_product_attention`` - routed to the jagged SDPA implementation.
    * ``apply_`` - applied to the ``_values`` of the first arg, which is then
      returned.
    * ``flatten`` - rewritten as a ``reshape`` on the NestedTensor.
    * ``rms_norm`` - input is validated, then ``func`` is called with torch
      function subclass dispatch disabled.

    Raises:
        ValueError: ``rms_norm`` would normalize over the ragged dim.
        NotImplementedError: ``func`` is none of the functions above.
    """
    # SDPA has special kernels that handle nested tensors.
    # Dispatch to the correct implementation here
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    if func.__name__ == "apply_":
        func(args[0]._values, *args[1:], **kwargs)
        return args[0]

    # Handle flatten() here because it's CompositeImplicit.
    if func.__name__ == "flatten":

        # Signature stub used only so normalize_function can bind args/kwargs
        # by name and fill in defaults.
        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(  # type: ignore[misc]
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # NB: stay in outer dim space because we're going to redispatch on a NT input
        start_dim = _wrap_jagged_dim(
            inp.dim(),
            new_kwargs["start_dim"],
            inp._ragged_idx,
            "flatten",
            convert_to_inner_dim=False,
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(),
            new_kwargs["end_dim"],
            inp._ragged_idx,
            "flatten",
            convert_to_inner_dim=False,
        )

        # Flattening a single dim is a no-op.
        if start_dim == end_dim:
            return inp

        # Collapse dims [start_dim, end_dim] into one dim holding their product.
        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])

        return inp.reshape(*new_shape)

    # Handle nested-specific input validation for CompositeImplicit rms_norm
    if func.__name__ == "rms_norm":

        # Signature stub used only so normalize_function can bind args/kwargs
        # by name and fill in defaults.
        def _rms_norm_sig(input, normalized_shape, weight=None, eps=None):
            pass

        _, new_kwargs = normalize_function(  # type: ignore[misc]
            _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")
        normalized_shape = new_kwargs.pop("normalized_shape")

        # can't normalize over the ragged dim (yet)
        # Only the dims strictly after the ragged dim may be normalized.
        max_normalizable = inp.dim() - inp._ragged_idx - 1
        if len(normalized_shape) > max_normalizable:
            raise ValueError(
                "rms_norm(): Normalization over the ragged dim not supported for nested tensors"
            )

        # The original args/kwargs are forwarded; the normalized ones were
        # only needed for validation.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    raise NotImplementedError(func)
  356. @register_jagged_func(
  357. [
  358. torch.ops.aten.is_non_overlapping_and_dense.default,
  359. torch.ops.aten.sym_size.default,
  360. torch.ops.aten.dim.default,
  361. torch.ops.aten.numel.default,
  362. torch.ops.aten.sym_numel.default,
  363. torch.ops.aten.sym_stride.default,
  364. torch.ops.aten.sym_storage_offset.default,
  365. ],
  366. "self: jt_all",
  367. )
  368. def tensor_attr_supported_getter(func, *args, **kwargs):
  369. if func == torch.ops.aten.is_non_overlapping_and_dense.default:
  370. return False
  371. if func == torch.ops.aten.sym_size.default:
  372. return args[0]._size
  373. if func == torch.ops.aten.dim.default:
  374. return len(args[0]._size)
  375. if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
  376. if args[0]._lengths is not None:
  377. return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
  378. return args[0]._values.numel()
  379. if func == torch.ops.aten.sym_stride.default:
  380. return args[0]._strides
  381. if func == torch.ops.aten.sym_storage_offset.default:
  382. return args[0]._values.storage_offset()
  383. @register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
  384. def prim_layout_default(func, *args, **kwargs):
  385. return torch.jagged
  386. @register_jagged_func(
  387. [torch.ops.aten.size.default],
  388. "self: jt_all",
  389. )
  390. def tensor_attr_unsupported_getter(func, *args, **kwargs):
  391. if func == torch.ops.aten.size.default:
  392. raise RuntimeError(
  393. "NestedTensor does not support directly calling torch.ops.aten.size; "
  394. "please use `nested_tensor.size()` instead."
  395. )
  396. @register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
  397. def is_contiguous_general(func, *args, **kwargs):
  398. from torch._prims_common import is_contiguous_for_memory_format
  399. _, new_kwargs = normalize_function( # type: ignore[misc]
  400. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  401. )
  402. inp = new_kwargs.pop("input")
  403. # If created from narrow() check for lengths
  404. if inp.lengths() is not None:
  405. return False
  406. new_kwargs["memory_format"] = new_kwargs.get(
  407. "memory_format", torch.contiguous_format
  408. )
  409. if new_kwargs["memory_format"] == torch.preserve_format:
  410. return True
  411. return is_contiguous_for_memory_format(inp._values, **new_kwargs)
  412. register_jagged_func(
  413. torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
  414. )(is_contiguous_general)
  415. @register_jagged_func(
  416. torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
  417. )
  418. def sym_is_contiguous_general(func, *args, **kwargs):
  419. _, new_kwargs = normalize_function( # type: ignore[misc]
  420. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  421. )
  422. inp = new_kwargs.pop("input")
  423. # If created from narrow() check for lengths
  424. if inp.lengths() is not None:
  425. return False
  426. new_kwargs["memory_format"] = new_kwargs.get(
  427. "memory_format", torch.contiguous_format
  428. )
  429. if new_kwargs["memory_format"] == torch.preserve_format:
  430. return True
  431. return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
  432. @register_jagged_func(
  433. torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
  434. )
  435. def clone_default(func, *args, **kwargs):
  436. _, new_kwargs = normalize_function( # type: ignore[misc]
  437. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  438. )
  439. inp = new_kwargs.pop("input")
  440. new_meta = extract_kwargs(inp)
  441. if inp._lengths is not None:
  442. if new_kwargs["memory_format"] == torch.contiguous_format:
  443. # need to copy to remove "holes" non-contiguity / lengths metadata
  444. # TODO: write a kernel for this
  445. from .nested_tensor import jagged_from_list
  446. # TODO: We probably want the output to have the same ragged structure / nested int.
  447. assert inp._ragged_idx == 1, (
  448. "NJT with ragged_idx != 1 not supported for contiguous clone"
  449. )
  450. contig, _ = jagged_from_list(inp.unbind(), offsets=None)
  451. return contig
  452. return NestedTensor(func(inp._values, **new_kwargs), **new_meta)
  453. @register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
  454. def linear_default(func, *args, **kwargs):
  455. _, new_kwargs = normalize_function( # type: ignore[misc]
  456. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  457. )
  458. inp = new_kwargs.pop("input")
  459. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  @register_jagged_func(
      torch.ops.aten.linear_backward.default,
      "self: jt, grad_output: jt, weight: t, output_mask: any",
  )
  def linear_backward_default(func, *args, **kwargs):
      """Compute the gradients of ``aten.linear`` for jagged nested tensors.

      Returns:
          A ``(grad_input, grad_weight, grad_bias)`` tuple. Each entry is only
          computed when the matching ``output_mask`` flag is truthy; otherwise it
          stays ``None``. ``grad_input`` is a NestedTensor, the other two are
          whatever the dense ops on the values buffers return.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      grad_njt = call_kwargs.pop("grad_output")
      weight = call_kwargs.pop("weight")
      output_mask = call_kwargs.pop("output_mask")
      grad_input = grad_weight = grad_bias = None

      check_ragged_dim_same(func, njt, "self", grad_njt, "grad_output")
      grad_values = grad_njt._values

      if output_mask[0]:
          grad_input = NestedTensor(
              torch.matmul(grad_values, weight), **extract_kwargs(grad_njt)
          )

      if output_mask[1]:
          # NB: Fold dims of values for input and grad_output to treat them as 2D. This
          # trick avoids materializing large intermediates and immediately reducing over
          # them via sum(). This is equivalent to computing:
          #     torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
          # and then summing over the leading dimensions to get a 2D weight grad.
          flat_grad = grad_values.reshape(-1, weight.size(0))
          flat_input = njt._values.reshape(-1, weight.size(1))
          grad_weight = torch.matmul(flat_grad.t(), flat_input)

      if output_mask[2]:
          # Sum over all but the last dim to get a 1D bias grad. We cannot
          # rely on the autograd engine to reduce for us, because returning a
          # tensor aliasing the input would violate the aten signature annotation
          leading_dims = tuple(range(grad_values.ndim - 1))
          if leading_dims:
              grad_bias = torch.sum(grad_values, leading_dims, keepdim=False)
          else:
              # nothing to reduce over: hand back a copy rather than an alias
              grad_bias = grad_values.clone()

      return (grad_input, grad_weight, grad_bias)
  @register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
  def to_dtype(func, *args, **kwargs):
      """Run ``aten.to.dtype`` on the values buffer of a nested tensor.

      The converted values are re-wrapped with the input's ``extract_kwargs``
      metadata.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      converted = func(njt._values, **call_kwargs)
      return NestedTensor(converted, **extract_kwargs(njt))
  @register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
  def to_copy_default(func, *args, **kwargs):
      """Run ``aten._to_copy`` on a nested tensor without changing its layout.

      ``func`` is applied to the values buffer (the ``layout`` argument is
      dropped first). Offsets, and lengths when present, are moved to the
      device of the converted values. The nested-int association of the
      original ragged metadata tensor (``nested_int_memo`` for fake/functional
      tensors, the ``_tensor_symint_registry`` entry otherwise) is copied onto
      the moved one before the output NestedTensor is built.
      """
      from torch._subclasses.fake_tensor import FakeTensor
      from torch._subclasses.functional_tensor import (
          FunctionalTensor,
          mb_unwrap_functional_tensor,
      )

      from .nested_tensor import _tensor_symint_registry

      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      # don't change layout
      call_kwargs.pop("layout")

      converted = func(njt._values, **call_kwargs)
      moved_offsets = njt._offsets.to(device=converted.device)
      moved_lengths = None
      if njt._lengths is not None:
          moved_lengths = njt._lengths.to(device=converted.device)

      # lengths take precedence over offsets as the ragged metadata tensor
      if njt._lengths is None:
          old_ragged = njt._offsets
      else:
          old_ragged = njt._lengths
      if moved_lengths is None:
          new_ragged = moved_offsets
      else:
          new_ragged = moved_lengths

      if isinstance(new_ragged, (FakeTensor, FunctionalTensor)):
          # Temporary hack until we have the union find
          unwrapped_new = mb_unwrap_functional_tensor(new_ragged)
          unwrapped_old = mb_unwrap_functional_tensor(old_ragged)
          unwrapped_new.nested_int_memo = unwrapped_old.nested_int_memo
      else:
          _tensor_symint_registry[new_ragged] = _tensor_symint_registry[old_ragged]

      out_kwargs = extract_kwargs(njt)
      out_kwargs["offsets"] = moved_offsets
      out_kwargs["lengths"] = moved_lengths

      return NestedTensor(converted, **out_kwargs)
  @register_jagged_func(
      torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
  )
  def copy_default(func, *args, **kwargs):
      """In-place copy of one nested tensor into another.

      When the two ``_size`` values differ, the components are unbound and
      copied pairwise, after checking that the per-component shapes agree.
      The values buffer of ``src`` is then copied into the values buffer of the
      destination in every case.

      Returns:
          The destination nested tensor.

      Raises:
          RuntimeError: if ``_size`` differs and the component shapes do not match.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      dst = call_kwargs.pop("input")
      src = call_kwargs.pop("src")

      sizes_differ = dst._size != src._size
      if sizes_differ:
          # try to recursively copy_ on unbound components to get around nested int mismatch
          # TODO: eventually do a direct copy when this is possible
          dst_parts = dst.unbind()
          dst_part_shapes = [part.shape for part in dst_parts]
          src_parts = src.unbind()
          src_part_shapes = [part.shape for part in src_parts]
          if dst_part_shapes != src_part_shapes:
              raise RuntimeError(
                  "copy_(): expected compatible input and src shapes, but got: "
                  f"{dst.shape} and {src.shape}"
              )
          for dst_part, src_part in zip(dst_parts, src_parts):
              dst_part.copy_(src_part)

      # AOTD allows mutations of inputs only, (not views of the inputs).
      # NJT.values() returns _values.detach() to workaround some issues.
      # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
      # Here we directly mutate self._values to not emit .detach() in the graph,
      # which would make it non-compilable.
      dst._values.copy_(src._values)
      return dst
  # detach() is registered against the shared unary pointwise implementation
  # rather than getting a dedicated handler.
  register_jagged_func(
      torch.ops.aten.detach.default,
      "self: jt_all",
  )(jagged_unary_pointwise)
  @register_jagged_func(
      [
          torch.ops.aten.empty_like.default,
          torch.ops.aten.ones_like.default,
          torch.ops.aten.zeros_like.default,
          torch.ops.aten.rand_like.default,
          torch.ops.aten.randn_like.default,
      ],
      "self: jt_all",
  )
  def like_factory_default(func, *args, **kwargs):
      """Shared handler for ``*_like`` factory ops on nested tensors.

      The factory op runs on the values buffer with ``layout=torch.strided``.
      Offsets, and lengths when present, are moved to the device of the new
      values. If that device differs from the input's, the nested-int
      association of the original ragged metadata tensor is copied to the
      moved one so both describe the same ragged structure.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")

      # Default layout is technically torch.strided but only jagged is supported here.
      # Rather than force users to specify the layout, assume jagged.
      # This should be set to strided for redispatching on values.
      call_kwargs["layout"] = torch.strided

      fresh_values = func(njt._values, **call_kwargs)
      moved_offsets = njt._offsets.to(device=fresh_values.device)
      moved_lengths = None
      if njt._lengths is not None:
          moved_lengths = njt._lengths.to(device=fresh_values.device)

      out_kwargs = extract_kwargs(njt)
      # only overwrite metadata entries that extract_kwargs() actually produced
      for key, moved in (("offsets", moved_offsets), ("lengths", moved_lengths)):
          if key in out_kwargs:
              out_kwargs[key] = moved

      if njt.device != fresh_values.device:
          # Update the nested int registry to indicate that the ragged structure is the same
          # between the two offsets / lengths on different devices.
          from torch._subclasses.fake_tensor import FakeTensor
          from torch._subclasses.functional_tensor import (
              FunctionalTensor,
              mb_unwrap_functional_tensor,
          )

          from .nested_tensor import _tensor_symint_registry

          if njt._lengths is None:
              old_ragged = njt._offsets
          else:
              old_ragged = njt._lengths
          if moved_lengths is None:
              new_ragged = moved_offsets
          else:
              new_ragged = moved_lengths

          if isinstance(new_ragged, (FakeTensor, FunctionalTensor)):
              # Temporary hack until we have the union find
              unwrapped_new = mb_unwrap_functional_tensor(new_ragged)
              unwrapped_old = mb_unwrap_functional_tensor(old_ragged)
              unwrapped_new.nested_int_memo = unwrapped_old.nested_int_memo
          else:
              _tensor_symint_registry[new_ragged] = _tensor_symint_registry[old_ragged]

      return NestedTensor(fresh_values, **out_kwargs)
  # The remaining *_like factories take extra positional arguments, so each is
  # registered with its own schema string but reuses like_factory_default.
  register_jagged_func(
      torch.ops.aten.full_like.default,
      "self: jt_all, fill_value: any",
  )(like_factory_default)

  register_jagged_func(
      torch.ops.aten.randint_like.default,
      "self: jt_all, high: any",
  )(like_factory_default)

  register_jagged_func(
      torch.ops.aten.randint_like.low_dtype,
      "self: jt_all, low: any, high: any",
  )(like_factory_default)
  @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
  def zero__default(func, *args, **kwargs):
      """Apply in-place ``aten.zero_`` to the values buffer and return the input."""
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      # the op mutates the buffer in place; its return value is not needed
      func(njt._values)
      return njt
  @register_jagged_func(
      torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
  )
  def _softmax_default(func, *args, **kwargs):
      """Softmax over one dimension of a nested tensor.

      A reduction over the ragged dim is computed by padding the values with
      ``-inf``, running a dense softmax, and converting back to jagged form.
      Any other supported dim is handled directly on the values buffer.

      Raises:
          RuntimeError: for a tuple ``dim``, for a reduction over the batch dim,
              for a ragged-dim reduction with ``_ragged_idx > 1``, and for a
              ragged-dim reduction when ``_lengths`` is set.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      if isinstance(call_kwargs["dim"], tuple):
          raise RuntimeError(
              "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
          )

      njt = call_kwargs.pop("input")

      wrapped_dims, on_batch, on_ragged, _on_non_batch = _wrap_jagged_dims(
          njt.dim(),
          (call_kwargs["dim"],),
          "softmax",
          njt._ragged_idx,
      )
      call_kwargs["dim"] = wrapped_dims

      if on_batch:
          raise RuntimeError(
              "softmax(): not supported when reducing across the batch dimension for NestedTensor"
          )

      if on_ragged and njt._ragged_idx > 1:
          raise RuntimeError(
              "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
          )

      if on_ragged and njt._lengths is not None:
          raise RuntimeError(
              "softmax(): not supported where lengths is not None "
              "if reducing across the ragged dimension for NestedTensor"
          )

      # torch.softmax takes in the reduction dimension as an integer
      call_kwargs["dim"] = call_kwargs["dim"][0]

      if not on_ragged:
          return NestedTensor(func(njt._values, **call_kwargs), **extract_kwargs(njt))

      # values are required to be 2D tensors for j2pd
      flat_values = njt._values.reshape(njt._values.shape[0], -1)
      padded = torch.ops.aten._jagged_to_padded_dense_forward(
          flat_values,
          [njt._offsets],
          max_lengths=[njt._max_seqlen],  # max length of ragged dimension
          padding_value=float("-inf"),  # e^-inf = 0
      )
      padded_softmax = torch.nn.functional.softmax(padded, dim=njt._ragged_idx)

      jagged_softmax = torch.ops.aten._padded_dense_to_jagged_forward(
          padded_softmax,
          [njt._offsets],
          # providing this parameter helps avoid a GPU/CPU sync
          total_L=njt._values.shape[0],
      )
      # expand back to the original shape of the values buffer
      restored = jagged_softmax.reshape(-1, *njt._values.shape[1:])
      return NestedTensor(restored, **extract_kwargs(njt))
  @register_jagged_func(
      torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any"
  )
  def _log_softmax_default(func, *args, **kwargs):
      """Log-softmax over a non-batch, non-ragged dim of a nested tensor.

      Raises:
          RuntimeError: for a tuple ``dim`` and for a reduction over the batch
              or the ragged dimension.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      if isinstance(call_kwargs["dim"], tuple):
          raise RuntimeError(
              "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
          )

      njt = call_kwargs.pop("input")

      wrapped_dims, on_batch, on_ragged, _on_non_batch = _wrap_jagged_dims(
          njt.dim(), (call_kwargs["dim"],), "log_softmax", njt._ragged_idx
      )
      call_kwargs["dim"] = wrapped_dims

      if on_batch:
          raise RuntimeError(
              "log_softmax(): not supported when reducing across the batch dimension for NestedTensor"
          )

      if on_ragged:
          raise RuntimeError(
              "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor"
          )

      # torch.log_softmax takes in the reduction dimension as an integer
      call_kwargs["dim"] = call_kwargs["dim"][0]

      result_values = func(njt._values, **call_kwargs)
      return NestedTensor(result_values, **extract_kwargs(njt))
  @register_jagged_func(
      torch.ops.aten._softmax_backward_data.default,
      "grad_output: jt, output: jt, dim: any, input_dtype: any",
  )
  def _softmax_backward(func, *args, **kwargs):
      """Run ``aten._softmax_backward_data`` on the values buffers.

      The result is wrapped with the nested metadata of ``grad_output``.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      grad_njt = call_kwargs.pop("grad_output")
      out_njt = call_kwargs.pop("output")

      grad_values = func(grad_njt._values, out_njt._values, **call_kwargs)
      return NestedTensor(grad_values, **extract_kwargs(grad_njt))
  # NOTE(review): the schema string names the second argument `float`; this looks
  # like it was meant to be the dropout probability's name. Confirm whether the
  # name is used for anything beyond error messages before changing it.
  @register_jagged_func(
      torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
  )
  def native_dropout_default(func, *args, **kwargs):
      """Run ``aten.native_dropout`` on the values buffer.

      Returns:
          A 2-tuple of NestedTensors, one per output of ``func``, each wrapped
          with the input's nested metadata.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      first_out, second_out = func(njt._values, **call_kwargs)
      wrapped_first = NestedTensor(first_out, **extract_kwargs(njt))
      wrapped_second = NestedTensor(second_out, **extract_kwargs(njt))
      return (wrapped_first, wrapped_second)
  @register_jagged_func(
      torch.ops.aten.native_dropout_backward.default,
      "grad_output: jt, mask: jt, scale: any",
  )
  def native_dropout_backward_default(func, *args, **kwargs):
      """Run ``aten.native_dropout_backward`` on the values buffers.

      The result is wrapped with the nested metadata of ``grad_output``.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      grad_njt = call_kwargs.pop("grad_output")
      mask_njt = call_kwargs.pop("mask")

      grad_values = func(grad_njt._values, mask_njt._values, **call_kwargs)
      return NestedTensor(grad_values, **extract_kwargs(grad_njt))
  @register_jagged_func(
      torch.ops.aten.prod.dim_int,
      "self: jt_all, dim: any, keepdim: any?, dtype: any?",
  )
  def prod_dim_int(func, *args, **kwargs):
      """Product reduction along ``dim``, delegated to ``_apply_reduction``."""
      # 1 is passed as the identity element for the product reduction
      multiplicative_identity = 1
      return _apply_reduction(func, "prod", multiplicative_identity, *args, **kwargs)
  @register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?")
  def prod_default(func, *args, **kwargs):
      """Full product reduction: apply ``func`` to the values buffer and return it as-is."""
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      # the result is returned unwrapped (no NestedTensor is built here)
      reduced = func(njt._values, **call_kwargs)
      return reduced
  @register_jagged_func(
      torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?"
  )
  def split_tensor(func, *args, **kwargs):
      """Split a nested tensor along ``dim`` by splitting its values buffer.

      ``dim`` is translated with ``_wrap_jagged_dim`` before ``func`` runs on
      the values buffer.

      Returns:
          A tuple of NestedTensors, one per piece, each carrying the input's
          nested metadata.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      call_kwargs["dim"] = _wrap_jagged_dim(
          njt.dim(), call_kwargs["dim"], njt._ragged_idx, "split"
      )

      pieces = func(njt._values, **call_kwargs)
      return tuple(
          NestedTensor(values=piece, **extract_kwargs(njt)) for piece in pieces
      )
  @register_jagged_func(
      torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?"
  )
  def split_with_sizes_default(func, *args, **kwargs):
      """Split a nested tensor along ``dim`` into pieces of the given sizes.

      ``dim`` is translated with ``_wrap_jagged_dim`` before ``func`` runs on
      the values buffer.

      Returns:
          A list of NestedTensors, one per piece, each carrying the input's
          nested metadata.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      call_kwargs["dim"] = _wrap_jagged_dim(
          njt.dim(), call_kwargs["dim"], njt._ragged_idx, "split_with_sizes"
      )

      pieces = func(njt._values, **call_kwargs)
      return [NestedTensor(values=piece, **extract_kwargs(njt)) for piece in pieces]
  @register_jagged_func(
      torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
  )
  def narrow(func, *args, **kwargs):
      """Narrow a nested tensor along ``dim`` by narrowing its values buffer.

      ``dim`` is translated with ``_wrap_jagged_dim``; ``start`` and ``length``
      are forwarded unchanged.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      njt = call_kwargs.pop("input")

      values_dim = _wrap_jagged_dim(
          njt.dim(), call_kwargs["dim"], njt._ragged_idx, "narrow"
      )
      narrowed = func(
          njt._values,
          dim=values_dim,
          start=call_kwargs["start"],
          length=call_kwargs["length"],
      )
      return NestedTensor(narrowed, **extract_kwargs(njt))
  @register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
  def chunk_default(func, *args, **kwargs):
      """Split a jagged nested tensor into chunks along ``dim``.

      Chunking over the batch dim partitions the components: per-chunk offsets
      are rebuilt from the component lengths and the values buffer is split
      along its ragged dim. Any other supported dim is chunked directly on the
      values buffer and every piece keeps the input's nested metadata.

      Returns:
          A list of NestedTensors. The number of chunks returned is not
          necessarily the requested ``chunks``; this matches dense behavior.
      """
      _, new_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      inp = new_kwargs.pop("input")

      new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
          inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
      )

      if not operating_on_batch:
          return [
              NestedTensor(values=x, **extract_kwargs(inp))
              for x in func(inp._values, **new_kwargs)
          ]

      chunks = new_kwargs["chunks"]

      # get _offsets of the chunks: per-chunk cumulative lengths with a leading 0
      lengths = inp._offsets.diff()
      chunked_lengths = lengths.chunk(chunks)
      chunked_offsets = [
          F.pad(torch.cumsum(x, dim=0), (1, 0), value=0)  # type: ignore[arg-type]
          for x in chunked_lengths
      ]

      # get _values of the chunks. The ragged dim of the values buffer sits at
      # _ragged_idx - 1 (same convention unbind uses), so split along that dim
      # instead of unconditionally along dim 0.
      split_sizes = [x.sum().item() for x in chunked_lengths]
      chunk_values = inp._values.split(split_sizes, dim=inp._ragged_idx - 1)

      # Note that the actual number of chunks returned is not necessarily the same as
      # the input number; it can be counter-intuitive, but it matches dense behavior.
      return [
          NestedTensor(values=vals, offsets=offs, _ragged_idx=inp._ragged_idx)
          for vals, offs in zip(chunk_values, chunked_offsets)
      ]
  @register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
  def unbind_int(func, *args, **kwargs):
      """Unbind a nested tensor along dim 0 into its components.

      Without lengths, the values are split by the differences of consecutive
      offsets. With lengths, each component is a ``narrow`` of the values
      starting at its offset with its length.

      Raises:
          RuntimeError: if ``dim`` is not 0, or (on the lengths path) if
              ``_ragged_idx`` is not >= 1.
      """
      # Note that this specializes on the length of the offsets
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      if call_kwargs["dim"] != 0:
          raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

      njt = call_kwargs.pop("input")
      values = njt.values()
      offsets = njt.offsets()
      lengths = njt.lengths()
      ragged_idx = njt._ragged_idx
      values_ragged_dim = ragged_idx - 1

      def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
          # This torch._check and torch._check_is_size are needed for torch.compile
          # symbolic shapes processing.
          # offsets and lengths are symbolic variables during compilation,
          # we guarantee the correct offsets/lengths correspondence:
          # sum of lengths <= total ragged_dim_size
          # every length and offset are size-like variable (allows sym shapes to reason it as [2, inf))
          # offset[i] + length[i] <= ragged_dim_size, for unbind and split dim correctness
          # offsets[i] <= ragged_dim_size
          running_total = 0
          ragged_dim_size = values.shape[ragged_idx - 1]
          for idx, length in enumerate(_lengths):
              torch._check_is_size(length)
              torch._check(length <= ragged_dim_size)

              running_total += length
              if _offsets is not None:
                  torch._check(
                      _offsets[idx] + length <= ragged_dim_size,
                      lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension",
                  )
          torch._check(running_total <= ragged_dim_size)

          if _offsets is not None:
              for offset in _offsets:
                  torch._check_is_size(offset)
                  torch._check(offset <= ragged_dim_size)

      if lengths is None:
          # contiguous case: component lengths come from consecutive offsets
          lengths_scalars = offsets.diff().tolist()
          _torch_check(lengths_scalars)

          return torch.split(values, lengths_scalars, dim=values_ragged_dim)

      if ragged_idx <= 0:
          raise RuntimeError(
              "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
          )

      lengths_scalars = lengths.tolist()
      offsets_scalars = offsets.tolist()

      _torch_check(lengths_scalars, offsets_scalars)

      components = []
      for i in range(lengths.shape[0]):
          components.append(
              torch.narrow(
                  values,
                  dim=values_ragged_dim,
                  start=offsets_scalars[i],
                  length=lengths_scalars[i],
              )
          )
      return components
  @register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
  def squeeze_dim(func, *args, **kwargs):
      """Squeeze ``dim`` of a jagged nested tensor via its values buffer.

      ``dim`` is translated with ``_wrap_jagged_dim`` and ``func`` is applied to
      the values buffer. When a dim located before the ragged dim is actually
      removed, the output's ``_ragged_idx`` is decremented so it keeps pointing
      at the ragged dim. This mirrors the adjustment ``unsqueeze`` makes when it
      inserts a dim before the ragged dim.
      """
      _, new_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      inp = new_kwargs.pop("input")
      values = inp._values

      new_kwargs["dim"] = _wrap_jagged_dim(
          len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
      )

      squeezed = func(values, **new_kwargs)
      output_kwargs = extract_kwargs(inp)

      # squeeze() is a no-op unless the dim has size 1, so only adjust when a
      # dim really disappeared. The wrapped dim indexes the values buffer, whose
      # ragged dim sits at _ragged_idx - 1.
      # (assumes _wrap_jagged_dim returns a non-negative dim, as the comparison
      # in unsqueeze relies on)
      dim_was_removed = squeezed.dim() < values.dim()
      if dim_was_removed and new_kwargs["dim"] < inp._ragged_idx - 1:
          output_kwargs["_ragged_idx"] -= 1

      return NestedTensor(squeezed, **output_kwargs)
  @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any")
  def unsqueeze_default(func, *args, **kwargs):
      """Insert a size-1 dim into a nested tensor via its values buffer.

      If the new dim lands at or before the ragged dim of the values buffer,
      the output's ``_ragged_idx`` is incremented.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      values = njt._values

      # Account for collapsed jagged dim
      requested_dim = call_kwargs["dim"]
      call_kwargs["dim"] = _wrap_jagged_dim(
          len(njt._size) + 1,
          requested_dim,
          njt._ragged_idx,
          "unsqueeze",
          allow_ragged_dim=True,
      )

      # ragged_idx changes if a dimension is added before it
      out_kwargs = extract_kwargs(njt)
      inserted_before_ragged = call_kwargs["dim"] <= njt._ragged_idx - 1
      if inserted_before_ragged:
          out_kwargs["_ragged_idx"] += 1

      return NestedTensor(func(values, **call_kwargs), **out_kwargs)
  @register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
  def cat_default(func, *args, **kwargs):
      """Concatenate tensors where at least one is a nested tensor.

      Non-nested inputs are first expanded to match the first nested input.
      ``func`` then runs on the values buffers, and the result takes the
      nested metadata of the first tensor in the list.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      tensors = call_kwargs.pop("tensors")

      # Convert any non-nested to nested
      nested = [t for t in tensors if t.is_nested]
      assert len(nested) > 0
      first = nested[0]
      as_nested = []
      for t in tensors:
          as_nested.append(t if t.is_nested else t.expand_as(first))
      tensors = as_nested

      # Account for collapsed jagged dim
      requested_dim = call_kwargs["dim"]
      call_kwargs["dim"] = _wrap_jagged_dim(
          len(first.shape), requested_dim, first._ragged_idx, "cat"
      )

      all_values = [t._values for t in tensors]
      return NestedTensor(func(all_values, **call_kwargs), **extract_kwargs(tensors[0]))
  @register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any")
  def matmul_default(func, *args, **kwargs):
      """Matrix multiply where at least one operand is a jagged nested tensor.

      Supported combinations (see the inline shape comments):
        * NJT x dense with matching dims, via conversion to padded form
        * NJT x 2D dense, applied directly to the values buffer
        * dense x NJT with matching dims, via conversion to padded form
        * 2D dense x NJT, applied directly to the values buffer
        * NJT x NJT with matching raggedness and more than 3 dims
        * 3D NJT x NJT reducing over the ragged dim, via unbind + stack
          (this case returns the result of ``torch.stack``)

      Raises:
          RuntimeError: if the operand combination is not one of the above.
      """
      _, new_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      inp = new_kwargs.pop("input")
      other = new_kwargs.pop("other")

      def _unbind_impl(a, b):
          # multiply component-by-component
          return [
              func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind())
          ]

      def _padded_impl(a, b):
          # exactly one of a / b is nested here; pad it, multiply densely, and
          # convert the result back using the nested operand's offsets
          if a.is_nested:
              nt = a
          else:
              nt = b

          from .nested_tensor import nested_from_padded

          min_seqlen = nt._maybe_min_seqlen
          max_seqlen = nt._maybe_max_seqlen
          padded_max_S = max_seqlen
          total_L = nt._values.shape[nt._ragged_idx - 1]
          if padded_max_S is None:
              # use upper bound on max seqlen if it's not present
              padded_max_S = total_L

          padded_shape = (
              *nt.shape[: nt._ragged_idx],
              padded_max_S,
              *nt.shape[nt._ragged_idx + 1 :],
          )
          padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape)
          if a.is_nested:
              padded_t = func(padded_nt, b)
          else:
              padded_t = func(a, padded_nt)
          return nested_from_padded(
              padded_t,
              offsets=nt._offsets,
              ragged_idx=nt._ragged_idx,
              sum_S=total_L,
              min_seqlen=min_seqlen,
              max_seqlen=max_seqlen,
          )

      # TODO: Back these with proper kernels (e.g. grouped GEMM)
      # NJT x dense
      if inp.is_nested and not other.is_nested:
          # (B, j1, D) x (B, D, E) => (B, j1, E)
          if (
              inp.dim() >= 3
              and inp.dim() == other.dim()
              and inp._ragged_idx < inp.dim() - 1
          ):
              # convert to padded for this
              return _padded_impl(inp, other)
          # Support broadcasting the dense:
          # (B, j1, D) x (D, E) => (B, j1, E)
          # (B, j1, D, E) x (E, F) => (B, j1, D, F)
          # etc.
          elif (
              other.dim() == 2
              and inp.dim() > other.dim()
              and inp._ragged_idx < inp.dim() - 1
          ):
              return NestedTensor(
                  func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
              )
      # Dense x NJT
      elif not inp.is_nested and other.is_nested:
          # (B, D, E) x (B, E, j1) => (B, D, j1)
          if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2:
              # convert to padded for this
              return _padded_impl(inp, other)
          # Support broadcasting the dense:
          # (D, E) x (B, E, j1) => (B, D, j1)
          # (D, E) x (B, E, j1, F) => (B, D, j1, F)
          # etc.
          elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2:
              return NestedTensor(
                  func(inp, other._values, **new_kwargs), **extract_kwargs(other)
              )
      # NJT x NJT
      elif inp.is_nested and other.is_nested:
          # Support ragged batch dim:
          # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc.
          if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
              return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))
          # Support reducing over ragged with dense output:
          # (B, D, j1) x (B, j1, E) => (B, D, E)
          elif (
              inp.dim() == 3
              and other.dim() == 3
              and inp._ragged_idx == 2
              and other._ragged_idx == 1
              and inp.size(inp._ragged_idx) == other.size(other._ragged_idx)
          ):
              # do unbind for this; can't use padded conversion due to j1 in last dim
              return torch.stack(_unbind_impl(inp, other))

      # `inp` may be a dense tensor here (unsupported dense x NJT case), and
      # dense tensors have no `_size`; fall back to `.shape` so the intended
      # RuntimeError is raised instead of an AttributeError. The message for a
      # nested `inp` is unchanged.
      inp_shape = inp._size if inp.is_nested else inp.shape
      raise RuntimeError(
          f"matmul(): not supported between inputs of shapes {inp_shape} and {other.shape}"
      )
  @register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any")
  def bmm_default(func, *args, **kwargs):
      """Batched matmul for nested tensors, delegated to ``matmul_default``.

      Raises:
          ValueError: if either operand is not 3D.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      lhs = call_kwargs.pop("input")
      rhs = call_kwargs.pop("mat2")

      if lhs.dim() != 3:
          raise ValueError("bmm(): input must be 3D")
      if rhs.dim() != 3:
          raise ValueError("bmm(): mat2 must be 3D")

      # reuse the matmul handler, handing it the matmul op instead of `func`
      matmul_op = torch.ops.aten.matmul.default
      return matmul_default(matmul_op, lhs, rhs)
  @register_jagged_func(
      torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?"
  )
  def expand_default(func, *args, **kwargs):
      """Expand a nested tensor to ``size`` via its values buffer.

      The batch dim is dropped from the expand argument and the ragged dim is
      passed as ``-1``; the remaining entries are taken from ``size``.

      Raises:
          RuntimeError: if ``raggedness_matches(inp, size)`` is false.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      size = call_kwargs["size"]

      assert ("implicit" not in call_kwargs) or (not call_kwargs.pop("implicit"))
      if not raggedness_matches(njt, size):
          raise RuntimeError(f"expand(): cannot expand shape {njt._size} -> {size}")

      values_size = []
      for d in range(1, njt.dim()):
          values_size.append(-1 if d == njt._ragged_idx else size[d])
      return NestedTensor(func(njt._values, values_size), **extract_kwargs(njt))
  @register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
  def expand_as_default(func, *args, **kwargs):
      """Expand a dense tensor to match a nested tensor.

      ``func`` is called with the dense input and the values buffer of
      ``other``; the result is wrapped with the nested metadata of ``other``.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      dense = call_kwargs.pop("input")
      target = call_kwargs.pop("other")

      expanded = func(dense, target._values)
      return NestedTensor(expanded, **extract_kwargs(target))
  @register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any")
  def broadcast_to(func, *args, **kwargs):
      """Broadcast a nested tensor to ``size`` by delegating to ``expand``.

      ``size`` is left-padded with 1s up to the input's number of dims.

      Raises:
          ValueError: if ``size`` has more entries than the input has dims.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      size = call_kwargs.pop("size")

      if len(size) > njt.dim():
          raise ValueError(
              "broadcast_to(): broadcasting to a higher-dim shape is currently not supported "
              "for nested tensors with the jagged layout"
          )

      missing_dims = njt.dim() - len(size)
      return njt.expand([1] * missing_dims + list(size))
  @register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any")
  def broadcast_tensors(func, *args, **kwargs):
      """Broadcast a mix of nested and dense tensors against each other.

      Nested inputs are broadcast to the common shape. Dense inputs with fewer
      dims than that shape are broadcast to the values shape of the first
      nested input and wrapped with its nested metadata.

      Returns:
          A tuple with one output per input.
          NOTE(review): a single-element input returns the bare tensor rather
          than a 1-tuple — confirm whether callers rely on that.

      Raises:
          ValueError: for an empty input, or for a dense input whose dim is
              equal to or higher than that of the broadcast shape.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      tensors = call_kwargs.pop("tensors")
      count = len(tensors)
      if count == 0:
          raise ValueError("broadcast_tensors(): expected at least one tensor input")
      if count == 1:
          return tensors[0]

      common_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
      # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible.
      reference = next(t for t in tensors if isinstance(t, NestedTensor))

      results = []
      for t in tensors:
          if t.is_nested:
              results.append(t.broadcast_to(common_shape))
              continue
          if t.dim() >= len(common_shape):
              raise ValueError(
                  "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal "
                  "or higher dim is not currently supported"
              )
          widened = t.broadcast_to(reference._values.shape)
          results.append(NestedTensor(widened, **extract_kwargs(reference)))

      return tuple(results)
  @register_jagged_func(
      torch.ops.aten.where.self, "condition: jt_all, self: any, other: any"
  )
  def where_self(func, *args, **kwargs):
      """Element-wise ``where`` with a nested ``condition``.

      All three operands are passed through ``torch.broadcast_tensors`` first;
      ``func`` then runs on their values buffers and the result takes the
      nested metadata of the broadcast condition.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      condition = call_kwargs.pop("condition")
      if_true = call_kwargs.pop("input")
      if_false = call_kwargs.pop("other")

      # if the tensors aren't compatible, broadcast_tensors() will let us know
      condition, if_true, if_false = torch.broadcast_tensors(
          condition, if_true, if_false
      )

      selected = func(
          condition._values, if_true._values, if_false._values, **call_kwargs
      )
      return NestedTensor(selected, **extract_kwargs(condition))
  @register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
  def _pin_memory_default(func, *args, **kwargs):
      """Run ``aten._pin_memory`` on the values buffer and re-wrap the result.

      The output takes the input's ``extract_kwargs`` metadata unchanged.
      """
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      pinned_values = func(njt._values, **call_kwargs)
      return NestedTensor(pinned_values, **extract_kwargs(njt))
  @register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
  def is_pinned_default(func, *args, **kwargs):
      """Return the result of ``aten.is_pinned`` applied to the values buffer."""
      _, call_kwargs = normalize_function(  # type: ignore[misc]
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )

      njt = call_kwargs.pop("input")
      # only the values buffer is queried
      values_pinned = func(njt._values, **call_kwargs)
      return values_pinned
  @register_jagged_func(
      torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
  )
  def is_same_size_default(func, *args, **kwargs):
      """Return whether the first two positional arguments have equal ``_size``."""
      lhs = args[0]
      rhs = args[1]
      return lhs._size == rhs._size
  1174. def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
  1175. _, new_kwargs = normalize_function( # type: ignore[misc]
  1176. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1177. )
  1178. inp = new_kwargs.pop("input")
  1179. # some ops use dim=None to indicate a full reduction; some use an empty dim list
  1180. full_reduction = new_kwargs["dim"] is None or (
  1181. isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0
  1182. )
  1183. if full_reduction:
  1184. out = func(inp._values, **new_kwargs)
  1185. if new_kwargs.get("keepdim", False):
  1186. if isinstance(out, (tuple, list)):
  1187. # some ops return multiple things; unsqueeze all of them
  1188. out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
  1189. else:
  1190. out = out.unsqueeze(inp._ragged_idx)
  1191. return out
  1192. # some ops support lists of dims; some don't
  1193. dim_to_convert = new_kwargs["dim"]
  1194. is_dimlist = isinstance(new_kwargs["dim"], (tuple, list))
  1195. if not is_dimlist:
  1196. dim_to_convert = [dim_to_convert]
  1197. (
  1198. converted_dim,
  1199. reduce_on_batch,
  1200. reduce_on_ragged,
  1201. reduce_on_non_batch,
  1202. ) = _wrap_jagged_dims(
  1203. inp.dim(),
  1204. dim_to_convert,
  1205. f"{func_name}",
  1206. inp._ragged_idx,
  1207. )
  1208. if not is_dimlist:
  1209. # convert back from list
  1210. converted_dim = converted_dim[0]
  1211. new_kwargs["dim"] = converted_dim
  1212. if reduce_on_ragged and inp._lengths is not None:
  1213. raise RuntimeError(
  1214. f"{func_name}(): reducing across the ragged dimension is not supported "
  1215. "for non-contiguous nested tensors with holes"
  1216. )
  1217. from torch.utils._pytree import tree_map
  1218. # raggedness reduced away --> return dense tensor
  1219. if reduce_on_ragged:
  1220. # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
  1221. if reduce_on_batch:
  1222. # no need to read offsets --> apply sum directly on values
  1223. out = func(inp._values, **new_kwargs)
  1224. if new_kwargs.get("keepdim", False):
  1225. # some ops return multiple things; unsqueeze all of them
  1226. out = tree_map(lambda o: o.unsqueeze(0), out)
  1227. return out
  1228. else:
  1229. # invalid reduction cases: (ragged, non-batch), etc.
  1230. if reduce_on_non_batch:
  1231. raise RuntimeError(
  1232. f"{func_name}(): reducing along a ragged and non-batch dimension "
  1233. "is not supported for nested tensors"
  1234. )
  1235. # reduction cases: (ragged)
  1236. # convert to padded dense and reduce
  1237. new_kwargs.pop("dim")
  1238. dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
  1239. return func(
  1240. inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
  1241. )
  1242. # raggedness preserved --> return nested tensor
  1243. else:
  1244. # invalid reduction cases: (batch), (batch, non-batch), etc.
  1245. if reduce_on_batch:
  1246. raise RuntimeError(
  1247. f"{func_name}(): reducing along the batch dimension but not "
  1248. "the ragged dimension is not supported for nested tensors"
  1249. )
  1250. # reduction cases: (non-batch), (non-batch, non-batch), etc.
  1251. # apply sum directly on values
  1252. out = func(inp._values, **new_kwargs)
  1253. out_kwargs = extract_kwargs(inp)
  1254. if not new_kwargs.get("keepdim", False):
  1255. # dims are reduced away -> ragged_idx of output needs to be reevaluated
  1256. dimlist = (
  1257. new_kwargs["dim"]
  1258. if isinstance(new_kwargs["dim"], (tuple, list))
  1259. else [new_kwargs["dim"]]
  1260. )
  1261. for d in dimlist:
  1262. # adjust for all dims reduced before the ragged dim
  1263. if d < inp._ragged_idx - 1:
  1264. out_kwargs["_ragged_idx"] -= 1
  1265. # some ops return multiple things; wrap each of them as an NJT
  1266. return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
  1267. @register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
  1268. def sum_default(func, *args, **kwargs):
  1269. _, new_kwargs = normalize_function( # type: ignore[misc]
  1270. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1271. )
  1272. inp = new_kwargs.pop("input")
  1273. return func(inp._values, **new_kwargs)
  1274. @register_jagged_func(
  1275. torch.ops.aten.sum.dim_IntList,
  1276. "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
  1277. )
  1278. def sum_dim_IntList(func, *args, **kwargs):
  1279. return _apply_reduction(func, "sum", 0, *args, **kwargs)
  1280. @register_jagged_func(
  1281. torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
  1282. )
  1283. def transpose_int(func, *args, **kwargs):
  1284. _, new_kwargs = normalize_function( # type: ignore[misc]
  1285. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1286. )
  1287. from torch._prims_common import canonicalize_dims
  1288. inp = new_kwargs.pop("input")
  1289. dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
  1290. # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
  1291. # instead of 1, although the internal Flash and mem-effn implementations will
  1292. # use the inputs with raggedness in dim 1.
  1293. if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
  1294. if dim0 == 0 or dim1 == 0:
  1295. raise ValueError(
  1296. "Transpose is not supported on the batch dimension for jagged NT"
  1297. )
  1298. if dim0 == inp._ragged_idx:
  1299. to_dim = dim1
  1300. else:
  1301. to_dim = dim0
  1302. inp_kwargs = extract_kwargs(inp)
  1303. inp_kwargs["_ragged_idx"] = to_dim
  1304. return NestedTensor(
  1305. inp.values().transpose(
  1306. _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
  1307. _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
  1308. ),
  1309. **inp_kwargs,
  1310. )
  1311. new_kwargs["dim0"] = _wrap_jagged_dim(
  1312. inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
  1313. )
  1314. new_kwargs["dim1"] = _wrap_jagged_dim(
  1315. inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
  1316. )
  1317. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1318. @register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
  1319. def permute_default(func, *args, **kwargs):
  1320. _, new_kwargs = normalize_function( # type: ignore[misc]
  1321. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1322. )
  1323. inp = new_kwargs.pop("input")
  1324. dims = new_kwargs.pop("dims")
  1325. inp_kwargs = extract_kwargs(inp)
  1326. inp_dim = len(inp._size)
  1327. # The first two checks are the same as the checks in the normal permute implementation
  1328. if inp_dim != len(dims):
  1329. raise ValueError(
  1330. f"permute(): number of dimensions in the tensor input ({inp_dim}) "
  1331. + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
  1332. )
  1333. from torch._prims_common import canonicalize_dims
  1334. canonicalized_dims = canonicalize_dims(inp_dim, dims)
  1335. if len(canonicalized_dims) != len(set(canonicalized_dims)):
  1336. raise ValueError("permute(): duplicate dims are not allowed.")
  1337. if inp._lengths is not None:
  1338. raise ValueError(
  1339. "permute(): not supported on jagged layout nested tensor with holes"
  1340. )
  1341. if canonicalized_dims[0] != 0:
  1342. raise ValueError(
  1343. "Permute is not supported on the batch dimension for jagged NT"
  1344. )
  1345. inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
  1346. inner_dims = [
  1347. _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
  1348. for dim in canonicalized_dims[1:]
  1349. ]
  1350. new_kwargs["dims"] = inner_dims
  1351. return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
  1352. @register_jagged_func(
  1353. [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
  1354. "self: jt_all, size: any",
  1355. )
  1356. def view_default(func, *args, **kwargs):
  1357. _, new_kwargs = normalize_function( # type: ignore[misc]
  1358. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1359. )
  1360. inp = new_kwargs.pop("input")
  1361. size = new_kwargs.pop("size")
  1362. if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
  1363. raise RuntimeError(
  1364. f"view(): does not support ragged_idx != 1 except when inp._size == size. "
  1365. f"inp._size is ({inp._size}) and size is ({size})."
  1366. )
  1367. # Ensure specified size still includes batch and ragged dims
  1368. if len(size) < 3 or not raggedness_matches(inp, size):
  1369. raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
  1370. # outer size: the size of the NT, e.g. [3, j0, 10]
  1371. # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
  1372. # this function gets inner_size[inner_idx] for a given inner_idx.
  1373. #
  1374. # example: for outer size [a, b, c, j0, d, e, f]
  1375. # assume that j0 is ragged, other are concrete integers
  1376. # and ragged_idx=3
  1377. # inner size will be [b, c, inp._values.size(ragged_idx), d, e, f]
  1378. # therefore:
  1379. # inner_size[0] = outer_size[1]
  1380. # inner_size[1] = outer_size[2]
  1381. # inner_size[0] = inp._values.size(ragged_idx - 1)
  1382. # inner_size[3] = outer_size[4]
  1383. # inner_size[4] = outer_size[5]
  1384. def get_inner_size(inner_idx):
  1385. nonlocal inp, size
  1386. if inner_idx == inp._ragged_idx - 1:
  1387. return inp._values.size(inner_idx)
  1388. else:
  1389. return size[inner_idx + 1]
  1390. inner_size = [get_inner_size(i) for i in range(len(size) - 1)]
  1391. # Preserve inference-mode-ness of input.
  1392. # TODO: Do this for all other views!
  1393. with torch.inference_mode(inp.is_inference()):
  1394. return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
  1395. @register_jagged_func(
  1396. torch.ops.aten.native_layer_norm.default,
  1397. "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
  1398. )
  1399. def native_layer_norm_default(func, *args, **kwargs):
  1400. _, new_kwargs = normalize_function( # type: ignore[misc]
  1401. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1402. )
  1403. inp = new_kwargs.pop("input")
  1404. if inp.dim() <= 2:
  1405. raise RuntimeError(
  1406. "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
  1407. )
  1408. normalized_shape = new_kwargs["normalized_shape"]
  1409. ragged_size = inp.shape[inp._ragged_idx]
  1410. num_dims_not_normalized = inp.dim() - len(normalized_shape)
  1411. if (
  1412. num_dims_not_normalized == 0
  1413. ): # error if trying to normalize over the batch dimension
  1414. raise RuntimeError(
  1415. "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
  1416. )
  1417. if ragged_size in normalized_shape and inp._lengths is not None:
  1418. raise RuntimeError(
  1419. "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
  1420. )
  1421. if (
  1422. ragged_size in normalized_shape
  1423. ): # special handling for normalizing over the ragged dimension
  1424. padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
  1425. inp._values.flatten(
  1426. start_dim=inp._ragged_idx
  1427. ), # _jagged_to_padded_dense_forward requires values to be a 2D tensor
  1428. [inp._offsets],
  1429. max_lengths=[inp._max_seqlen], # max length of ragged dimension
  1430. )
  1431. padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
  1432. torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
  1433. [inp._offsets],
  1434. max_lengths=[inp._max_seqlen], # max length of ragged dimension
  1435. ).expand(
  1436. padded_input.shape
  1437. ) # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor)
  1438. ragged_lengths = (
  1439. inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
  1440. ) # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize)
  1441. mean = (
  1442. torch.sum(
  1443. padded_input,
  1444. dim=(1, 2),
  1445. keepdim=True,
  1446. )
  1447. / ragged_lengths
  1448. ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
  1449. padded_normalized = (
  1450. (padded_input - mean) * padded_mask
  1451. ) # mask elements outside of the ragged dimension size for correct variance calculation
  1452. variance = (
  1453. torch.sum(
  1454. torch.square(padded_normalized),
  1455. dim=(1, 2),
  1456. keepdim=True,
  1457. )
  1458. / ragged_lengths
  1459. ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
  1460. std = torch.sqrt(variance + new_kwargs["eps"])
  1461. padded_layer_norm = padded_normalized / std
  1462. jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
  1463. padded_layer_norm,
  1464. [inp._offsets],
  1465. total_L=inp._values.shape[
  1466. 0
  1467. ], # providing this parameter helps avoid a GPU/CPU sync
  1468. ).unflatten(
  1469. -1, inp.shape[inp._ragged_idx + 1 :]
  1470. ) # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H)
  1471. return (
  1472. NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
  1473. mean,
  1474. std,
  1475. )
  1476. output, mean, std = func(inp._values, **new_kwargs)
  1477. return (NestedTensor(output, **extract_kwargs(inp)), mean, std)
  1478. @register_jagged_func(
  1479. torch.ops.aten.native_layer_norm_backward.default,
  1480. "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
  1481. )
  1482. def native_layer_norm_backward_default(func, *args, **kwargs):
  1483. _, new_kwargs = normalize_function( # type: ignore[misc]
  1484. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1485. )
  1486. grad_out = new_kwargs.pop("grad_out")
  1487. inp = new_kwargs.pop("input")
  1488. d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
  1489. if d_input is None:
  1490. return (None, d_gamma, d_beta)
  1491. return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)
  1492. @register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any")
  1493. def select_int(func, *args, **kwargs):
  1494. _, new_kwargs = normalize_function( # type: ignore[misc]
  1495. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1496. )
  1497. inp = new_kwargs.pop("input")
  1498. new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
  1499. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
  1500. )
  1501. # handle batch dim slicing via unbind() for now
  1502. # TODO: make this more efficient
  1503. if operating_on_batch:
  1504. return inp.unbind()[new_kwargs["index"]]
  1505. if inp._lengths is not None:
  1506. raise ValueError(
  1507. "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes"
  1508. )
  1509. # if selecting before the ragged dim, adjust output ragged_idx
  1510. out_kwargs = extract_kwargs(inp)
  1511. if new_kwargs["dim"] < inp._ragged_idx - 1:
  1512. out_kwargs["_ragged_idx"] -= 1
  1513. return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs)
  1514. @register_jagged_func(
  1515. torch.ops.aten.slice.Tensor,
  1516. "self: jt, dim: any?, start: any?, end: any?, step: any?",
  1517. )
  1518. def slice_tensor(func, *args, **kwargs):
  1519. _, new_kwargs = normalize_function( # type: ignore[misc]
  1520. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1521. )
  1522. inp = new_kwargs.pop("input")
  1523. new_kwargs["dim"] = _wrap_jagged_dim(
  1524. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
  1525. )
  1526. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1527. @register_jagged_func(
  1528. torch.ops.aten.index_put.default,
  1529. "input: jt_all, indices: any, values: t, accumulate: any?",
  1530. )
  1531. @register_jagged_func(
  1532. torch.ops.aten.index_put_.default,
  1533. "input: jt_all, indices: any, values: t, accumulate: any?",
  1534. )
  1535. def index_put_(func, *args, **kwargs):
  1536. _, new_kwargs = normalize_function( # type: ignore[misc]
  1537. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1538. )
  1539. inp: NestedTensor = new_kwargs.pop("input")
  1540. # For index_put_ to work, we add together the indices of the ragged dimension
  1541. # and the batch dimension, adding the offsets of each ragged dimension to its
  1542. # indices
  1543. indices = new_kwargs.pop("indices")
  1544. assert len(indices) <= inp.dim()
  1545. if len(indices) < inp._ragged_idx + 1:
  1546. if not inp.is_contiguous():
  1547. raise RuntimeError(
  1548. "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
  1549. )
  1550. # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
  1551. from .nested_tensor import nested_from_padded
  1552. min_seqlen = inp._maybe_min_seqlen
  1553. max_seqlen = inp._maybe_max_seqlen
  1554. padded_max_S = max_seqlen
  1555. total_L = inp._values.shape[inp._ragged_idx - 1]
  1556. if padded_max_S is None:
  1557. # use upper bound on max seqlen if it's not present
  1558. padded_max_S = total_L
  1559. padded_shape = (
  1560. *inp.shape[: inp._ragged_idx],
  1561. padded_max_S,
  1562. *inp.shape[inp._ragged_idx + 1 :],
  1563. )
  1564. padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
  1565. new_njt = nested_from_padded(
  1566. func(padded_inp, indices, **new_kwargs),
  1567. offsets=inp._offsets,
  1568. ragged_idx=inp._ragged_idx,
  1569. sum_S=total_L,
  1570. min_seqlen=min_seqlen,
  1571. max_seqlen=max_seqlen,
  1572. )
  1573. if func == torch.ops.aten.index_put_.default:
  1574. inp._values.copy_(new_njt.values())
  1575. return inp
  1576. return new_njt
  1577. # We can run on the underlying values directly
  1578. # Validate indices
  1579. if inp.lengths() is None:
  1580. lengths = inp.offsets().diff()
  1581. else:
  1582. lengths = inp.lengths()
  1583. torch._assert_async(
  1584. torch.all(indices[inp._ragged_idx] < lengths),
  1585. "Some indices in the ragged dimension are out of bounds!",
  1586. )
  1587. # Recompute indices for _values
  1588. ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
  1589. func_indices = (
  1590. # before ragged dim
  1591. indices[1 : inp._ragged_idx]
  1592. # ragged dim (combined with batch)
  1593. + [ragged_indices]
  1594. # after ragged dim
  1595. + indices[inp._ragged_idx + 1 :]
  1596. )
  1597. if func == torch.ops.aten.index_put_.default:
  1598. inp._values = func(inp._values, func_indices, **new_kwargs)
  1599. return inp
  1600. return NestedTensor(
  1601. func(inp._values, func_indices, **new_kwargs),
  1602. **extract_kwargs(inp),
  1603. )
  1604. @register_jagged_func(
  1605. torch.ops.aten.convolution.default,
  1606. "input: jt, weight: t, bias: t?, stride: any, padding: any, "
  1607. "dilation: any, transposed: any, output_padding: any, groups: any",
  1608. )
  1609. def convolution_default(func, *args, **kwargs):
  1610. _, new_kwargs = normalize_function( # type: ignore[misc]
  1611. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1612. )
  1613. inp = new_kwargs.pop("input")
  1614. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1615. @register_jagged_func(
  1616. torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
  1617. )
  1618. def mean_dim(func, *args, **kwargs):
  1619. _, new_kwargs = normalize_function( # type: ignore[misc]
  1620. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1621. )
  1622. inp = new_kwargs["input"]
  1623. (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims(
  1624. inp.dim(),
  1625. new_kwargs["dim"],
  1626. "mean",
  1627. inp._ragged_idx,
  1628. )
  1629. if reduce_on_ragged and not reduce_on_batch:
  1630. assert not reduce_on_non_batch
  1631. # calculate an intermediate sum and leave the dim in for normalization purposes
  1632. keepdim = new_kwargs["keepdim"]
  1633. new_kwargs["keepdim"] = True
  1634. intermediate_sum = _apply_reduction(
  1635. torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs
  1636. )
  1637. # normalize by sequence lengths
  1638. lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff()
  1639. for _ in range(intermediate_sum.dim() - 1):
  1640. lengths = lengths.unsqueeze(-1)
  1641. out = intermediate_sum / lengths
  1642. if not keepdim:
  1643. out = out.squeeze(inp._ragged_idx)
  1644. return out
  1645. # at this point, we're just redispatching on the values buffer
  1646. # since we expect it to be unused, specify a weird intermediate value to
  1647. # hopefully make errors obvious
  1648. intermediate_value = 0.42
  1649. return _apply_reduction(func, "mean", intermediate_value, **new_kwargs)
  1650. @register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?")
  1651. def mean_default(func, *args, **kwargs):
  1652. _, new_kwargs = normalize_function( # type: ignore[misc]
  1653. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1654. )
  1655. inp = new_kwargs.pop("input")
  1656. return func(inp._values, **new_kwargs)
  1657. @register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?")
  1658. def any_dims(func, *args, **kwargs):
  1659. return _apply_reduction(func, "any", False, *args, **kwargs)
  1660. @register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?")
  1661. def any_dim(func, *args, **kwargs):
  1662. _, new_kwargs = normalize_function( # type: ignore[misc]
  1663. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1664. )
  1665. # wrap dim in list to redispatch to dims overload
  1666. new_kwargs["dim"] = [new_kwargs["dim"]]
  1667. return any_dims(torch.ops.aten.any.dims, **new_kwargs)
  1668. @register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?")
  1669. def all_dims(func, *args, **kwargs):
  1670. return _apply_reduction(func, "all", True, *args, **kwargs)
  1671. @register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?")
  1672. def all_dim(func, *args, **kwargs):
  1673. _, new_kwargs = normalize_function( # type: ignore[misc]
  1674. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1675. )
  1676. # wrap dim in list to redispatch to dims overload
  1677. new_kwargs["dim"] = [new_kwargs["dim"]]
  1678. return all_dims(torch.ops.aten.all.dims, **new_kwargs)
  1679. @register_jagged_func(
  1680. [
  1681. torch.ops.aten.all.default,
  1682. torch.ops.aten.any.default,
  1683. torch.ops.aten.max.default,
  1684. torch.ops.aten.min.default,
  1685. ],
  1686. "self: jt_all",
  1687. )
  1688. def all_any_max_min_default(func, *args, **kwargs):
  1689. _, new_kwargs = normalize_function( # type: ignore[misc]
  1690. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1691. )
  1692. inp = new_kwargs.pop("input")
  1693. return func(inp._values, **new_kwargs)
  1694. @register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
  1695. def min_dim(func, *args, **kwargs):
  1696. _, new_kwargs = normalize_function( # type: ignore[misc]
  1697. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1698. )
  1699. dtype_max = torch.finfo(new_kwargs["input"].dtype).max
  1700. return _apply_reduction(func, "min", dtype_max, *args, **kwargs)
  1701. @register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?")
  1702. def max_dim(func, *args, **kwargs):
  1703. _, new_kwargs = normalize_function( # type: ignore[misc]
  1704. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1705. )
  1706. dtype_min = torch.finfo(new_kwargs["input"].dtype).min
  1707. return _apply_reduction(func, "max", dtype_min, *args, **kwargs)
  1708. @register_jagged_func(
  1709. torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?"
  1710. )
  1711. def amin_default(func, *args, **kwargs):
  1712. _, new_kwargs = normalize_function( # type: ignore[misc]
  1713. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1714. )
  1715. dtype_max = torch.finfo(new_kwargs["input"].dtype).max
  1716. return _apply_reduction(func, "amin", dtype_max, *args, **kwargs)
  1717. @register_jagged_func(
  1718. torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?"
  1719. )
  1720. def amax_default(func, *args, **kwargs):
  1721. _, new_kwargs = normalize_function( # type: ignore[misc]
  1722. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1723. )
  1724. dtype_min = torch.finfo(new_kwargs["input"].dtype).min
  1725. return _apply_reduction(func, "amax", dtype_min, *args, **kwargs)
  1726. @register_jagged_func(
  1727. torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?"
  1728. )
  1729. def argmin_default(func, *args, **kwargs):
  1730. _, new_kwargs = normalize_function( # type: ignore[misc]
  1731. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1732. )
  1733. dtype_max = torch.finfo(new_kwargs["input"].dtype).max
  1734. return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs)
  1735. @register_jagged_func(
  1736. torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?"
  1737. )
  1738. def argmax_default(func, *args, **kwargs):
  1739. _, new_kwargs = normalize_function( # type: ignore[misc]
  1740. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1741. )
  1742. dtype_min = torch.finfo(new_kwargs["input"].dtype).min
  1743. return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
  1744. @register_jagged_func(
  1745. torch.ops.aten.value_selecting_reduction_backward.default,
  1746. "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
  1747. )
  1748. def value_selecting_reduction_backward_default(func, *args, **kwargs):
  1749. from torch.fx.experimental.symbolic_shapes import is_nested_int
  1750. _, new_kwargs = normalize_function( # type: ignore[misc]
  1751. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1752. )
  1753. grad = new_kwargs.pop("grad")
  1754. new_kwargs["grad"] = grad._values
  1755. indices = new_kwargs.pop("indices")
  1756. new_kwargs["indices"] = indices._values
  1757. # should always succeed; sizes should contain a nested int
  1758. ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
  1759. # convert dim -> values-space dim
  1760. new_kwargs["dim"] = _wrap_jagged_dim(
  1761. len(new_kwargs["sizes"]),
  1762. new_kwargs["dim"],
  1763. ragged_idx,
  1764. "value_selecting_reduction_backward",
  1765. )
  1766. # convert saved NJT sizes -> values-space sizes
  1767. sizes = new_kwargs.pop("sizes")
  1768. sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
  1769. sizes = sizes[1:]
  1770. new_kwargs["sizes"] = sizes
  1771. output_kwargs = extract_kwargs(indices)
  1772. output_kwargs["_ragged_idx"] = ragged_idx
  1773. return NestedTensor(func(**new_kwargs), **output_kwargs)
  1774. @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
  1775. def stack_default(func, *args, **kwargs):
  1776. _, new_kwargs = normalize_function( # type: ignore[misc]
  1777. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1778. )
  1779. # guaranteed this is non-empty if we got here
  1780. tensors = new_kwargs.pop("tensors")
  1781. for t in tensors:
  1782. if not isinstance(t, NestedTensor):
  1783. raise RuntimeError("stack(): expected all nested tensors inputs")
  1784. if t.dim() != tensors[0].dim():
  1785. raise RuntimeError(
  1786. "stack(): expected all nested tensors to have the same dim"
  1787. )
  1788. if not raggedness_matches(t, tensors[0].shape):
  1789. raise RuntimeError(
  1790. "stack(): expected all nested tensors to have the same nested structure"
  1791. )
  1792. new_kwargs["dim"] = _wrap_jagged_dim(
  1793. tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
  1794. )
  1795. return NestedTensor(
  1796. func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
  1797. )
  1798. @register_jagged_func(
  1799. torch.ops.aten.embedding.default,
  1800. "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
  1801. )
  1802. def embedding_default(func, *args, **kwargs):
  1803. _, new_kwargs = normalize_function( # type: ignore[misc]
  1804. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1805. )
  1806. # guaranteed this is non-empty if we got here
  1807. indices = new_kwargs.pop("indices")
  1808. weight = new_kwargs.pop("weight")
  1809. return NestedTensor(
  1810. func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
  1811. )
  1812. @register_jagged_func(
  1813. torch.ops.aten.embedding_dense_backward.default,
  1814. "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
  1815. )
  1816. def embedding_dense_backward_default(func, *args, **kwargs):
  1817. _, new_kwargs = normalize_function( # type: ignore[misc]
  1818. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1819. )
  1820. indices = new_kwargs.pop("indices")
  1821. grad_output = new_kwargs.pop("grad_output")
  1822. return func(grad_output._values, indices._values, **new_kwargs)
  1823. @register_jagged_func(
  1824. [
  1825. torch.ops.aten.values.default,
  1826. torch.ops.aten._nested_get_values.default,
  1827. ],
  1828. "self: jt_all",
  1829. )
  1830. def values_default(func, *args, **kwargs):
  1831. _, new_kwargs = normalize_function( # type: ignore[misc]
  1832. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1833. )
  1834. inp = new_kwargs.pop("input")
  1835. # TODO: Handle inference mode properly.
  1836. # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
  1837. return inp._values.detach()
  1838. @register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
  1839. def all_default(func, *args, **kwargs):
  1840. _, new_kwargs = normalize_function( # type: ignore[misc]
  1841. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1842. )
  1843. inp = new_kwargs.pop("input")
  1844. return func(inp._values)
  1845. @register_jagged_func(
  1846. torch.ops.aten.to_padded_tensor.default,
  1847. "self: jt_all, padding: any, output_size: any?",
  1848. )
  1849. def to_padded_tensor_default(func, *args, **kwargs):
  1850. _, new_kwargs = normalize_function( # type: ignore[misc]
  1851. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1852. )
  1853. inp = new_kwargs.pop("input")
  1854. if inp._lengths is not None:
  1855. raise RuntimeError(
  1856. "to_padded_tensor(): not supported for nested tensors with holes"
  1857. )
  1858. # TODO: Handle the rest of output_size
  1859. output_size = new_kwargs["output_size"]
  1860. if output_size is not None:
  1861. max_seq_len = output_size[inp._ragged_idx]
  1862. else:
  1863. max_seq_len = (
  1864. inp._max_seqlen
  1865. if inp._max_seqlen_tensor is not None
  1866. else inp._values.size(0)
  1867. )
  1868. # only 2D values with ragged packed dim=0 is supported by the underlying FBGEMM
  1869. # kernel so do shape gymnastics if needed
  1870. values = inp.values()
  1871. if inp._ragged_idx > 1:
  1872. values = values.transpose(inp._ragged_idx - 1, 0)
  1873. values_shape = values.shape
  1874. if values.dim() > 2:
  1875. values = values.flatten(start_dim=1)
  1876. elif values.dim() == 1:
  1877. values = values.unsqueeze(-1)
  1878. # NB: The CUDA kernel for jagged -> padded dense conversion does not support
  1879. # integer / bool types; work around this by casting to half.
  1880. is_bool = values.dtype is torch.bool
  1881. if is_bool and values.is_cuda:
  1882. values = values.to(torch.half)
  1883. padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
  1884. values,
  1885. [inp._offsets],
  1886. [max_seq_len],
  1887. new_kwargs["padding"],
  1888. )
  1889. if is_bool and padded_out.is_cuda:
  1890. padded_out = padded_out.to(torch.bool)
  1891. # shape gymnastics part 2
  1892. if len(values_shape) > 2:
  1893. padded_out = padded_out.unflatten(-1, values_shape[1:])
  1894. elif len(values_shape) == 1:
  1895. padded_out = padded_out.squeeze(-1)
  1896. if inp._ragged_idx > 1:
  1897. padded_out = padded_out.transpose(inp._ragged_idx, 1)
  1898. return padded_out
  1899. @register_jagged_func(
  1900. torch.ops.aten._nested_from_padded_tensor.default,
  1901. "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?",
  1902. )
  1903. def _nested_from_padded_tensor_default(func, *args, **kwargs):
  1904. _, new_kwargs = normalize_function( # type: ignore[misc]
  1905. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1906. )
  1907. padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]
  1908. ragged_idx = new_kwargs.get("ragged_idx", 1)
  1909. # only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM
  1910. # kernel so do shape gymnastics
  1911. if ragged_idx > 1:
  1912. padded = padded.transpose(ragged_idx, 1)
  1913. padded_ragged_dim1_shape = padded.shape
  1914. if padded.dim() > 3:
  1915. padded = padded.flatten(start_dim=2)
  1916. elif padded.dim() < 3:
  1917. padded = padded.unsqueeze(-1)
  1918. # NB: The CUDA kernel for padded dense -> jagged conversion does not support
  1919. # integer / bool types; work around this by casting to half.
  1920. is_bool = padded.dtype is torch.bool
  1921. if is_bool and padded.is_cuda:
  1922. padded = padded.to(torch.half)
  1923. values = torch.ops.aten._padded_dense_to_jagged_forward(
  1924. padded, [offsets], new_kwargs["sum_S"]
  1925. )
  1926. if is_bool and values.is_cuda:
  1927. values = values.to(torch.bool)
  1928. # shape gymnastics part 2
  1929. if len(padded_ragged_dim1_shape) > 3:
  1930. values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
  1931. elif len(padded_ragged_dim1_shape) < 3:
  1932. values = values.squeeze(-1)
  1933. if ragged_idx > 1:
  1934. values = values.transpose(ragged_idx - 1, 0)
  1935. min_seqlen = new_kwargs["min_seqlen"]
  1936. max_seqlen = new_kwargs["max_seqlen"]
  1937. metadata_cache = {}
  1938. if min_seqlen is not None:
  1939. metadata_cache["min_seqlen"] = min_seqlen
  1940. if max_seqlen is not None:
  1941. metadata_cache["max_seqlen"] = max_seqlen
  1942. return NestedTensor(
  1943. values,
  1944. offsets,
  1945. _ragged_idx=ragged_idx,
  1946. _metadata_cache=metadata_cache,
  1947. )
  1948. @register_jagged_func(
  1949. torch.ops.aten._nested_view_from_jagged.default,
  1950. "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
  1951. )
  1952. def _nested_view_from_jagged_default(func, *args, **kwargs):
  1953. _, new_kwargs = normalize_function( # type: ignore[misc]
  1954. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1955. )
  1956. values, offsets, lengths = (
  1957. new_kwargs["input"],
  1958. new_kwargs["offsets"],
  1959. new_kwargs["lengths"],
  1960. )
  1961. ragged_idx = new_kwargs["ragged_idx"]
  1962. min_seqlen = new_kwargs["min_seqlen"]
  1963. max_seqlen = new_kwargs["max_seqlen"]
  1964. metadata_cache = {}
  1965. if min_seqlen is not None:
  1966. metadata_cache["min_seqlen"] = min_seqlen
  1967. if max_seqlen is not None:
  1968. metadata_cache["max_seqlen"] = max_seqlen
  1969. return NestedTensor(
  1970. values,
  1971. offsets,
  1972. lengths=lengths,
  1973. _ragged_idx=ragged_idx,
  1974. _metadata_cache=metadata_cache,
  1975. )
  1976. @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
  1977. def _nested_get_offsets(func, *args, **kwargs):
  1978. _, new_kwargs = normalize_function( # type: ignore[misc]
  1979. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1980. )
  1981. inp = new_kwargs.pop("input")
  1982. return inp._offsets
  1983. @register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
  1984. def _nested_get_lengths(func, *args, **kwargs):
  1985. _, new_kwargs = normalize_function( # type: ignore[misc]
  1986. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1987. )
  1988. inp = new_kwargs.pop("input")
  1989. return inp._lengths
  1990. @register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
  1991. def _nested_get_ragged_idx(func, *args, **kwargs):
  1992. _, new_kwargs = normalize_function( # type: ignore[misc]
  1993. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1994. )
  1995. inp = new_kwargs.pop("input")
  1996. return inp._ragged_idx
  1997. @register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
  1998. def _nested_get_min_seqlen(func, *args, **kwargs):
  1999. _, new_kwargs = normalize_function( # type: ignore[misc]
  2000. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2001. )
  2002. inp = new_kwargs.pop("input")
  2003. return inp._metadata_cache.get("min_seqlen", None)
  2004. @register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
  2005. def _nested_get_max_seqlen(func, *args, **kwargs):
  2006. _, new_kwargs = normalize_function( # type: ignore[misc]
  2007. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2008. )
  2009. inp = new_kwargs.pop("input")
  2010. return inp._metadata_cache.get("max_seqlen", None)
  2011. # If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
  2012. @register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
  2013. def masked_select_default(func, *args, **kwargs):
  2014. _, new_kwargs = normalize_function( # type: ignore[misc]
  2015. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2016. )
  2017. inp = new_kwargs.pop("input")
  2018. mask = new_kwargs.pop("mask")
  2019. if inp.ndim > 2:
  2020. raise RuntimeError("masked_select only support 2-D selections currently")
  2021. elif inp.shape != mask.shape:
  2022. raise RuntimeError(
  2023. f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
  2024. )
  2025. res_values = inp._values.masked_select(mask.values())
  2026. mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0)) # type: ignore[arg-type]
  2027. args = extract_kwargs(inp)
  2028. args["offsets"] = mask_cumsum[inp._offsets]
  2029. return NestedTensor(
  2030. values=res_values,
  2031. **args,
  2032. )
  2033. @register_jagged_func(
  2034. torch.ops.aten._nested_select_backward.default,
  2035. "grad_output: t, self: jt_all, dim: any, index: any",
  2036. )
  2037. def _nested_select_backward_default(func, *args, **kwargs):
  2038. _, new_kwargs = normalize_function( # type: ignore[misc]
  2039. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2040. )
  2041. inp = new_kwargs.pop("input")
  2042. grad_output = new_kwargs.pop("grad_output")
  2043. grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
  2044. grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
  2045. return grad_input
  2046. @register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
  2047. def record_stream_default(func, *args, **kwargs):
  2048. inp = args[0]
  2049. stream = args[1]
  2050. # ensure all components live until stream computation completes
  2051. func(inp._values, stream)
  2052. func(inp._offsets, stream)
  2053. if inp._lengths is not None:
  2054. func(inp._lengths, stream)
  2055. @register_jagged_func(
  2056. [
  2057. torch.ops.aten.new_empty.default,
  2058. torch.ops.aten.new_zeros.default,
  2059. torch.ops.aten.new_ones.default,
  2060. ],
  2061. "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
  2062. )
  2063. def new_empty_default(func, *args, **kwargs):
  2064. _, new_kwargs = normalize_function( # type: ignore[misc]
  2065. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2066. )
  2067. inp = new_kwargs.pop("input")
  2068. if len(new_kwargs["size"]) == 0:
  2069. return func(inp._values, **new_kwargs)
  2070. raise RuntimeError("new_empty() not supported for NJT with shape != ()")
  2071. @register_jagged_func(
  2072. [
  2073. torch.ops.aten.elu_backward.default,
  2074. torch.ops.aten.hardshrink_backward.default,
  2075. torch.ops.aten.hardsigmoid_backward.default,
  2076. torch.ops.aten.hardtanh_backward.default,
  2077. torch.ops.aten.softplus_backward.default,
  2078. torch.ops.aten.softshrink_backward.default,
  2079. ],
  2080. "self: jt_all, ...",
  2081. )
  2082. def activation_backward(func, *args, **kwargs):
  2083. # first NJT arg is expected to be grad_output
  2084. grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
  2085. return NestedTensor(
  2086. func(
  2087. *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
  2088. **kwargs,
  2089. ),
  2090. **extract_kwargs(grad_output),
  2091. )
  2092. @register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any")
  2093. def fill_Scalar(func, *args, **kwargs):
  2094. _, new_kwargs = normalize_function( # type: ignore[misc]
  2095. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2096. )
  2097. inp = new_kwargs.pop("input")
  2098. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  2099. @register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
  2100. def fill__Scalar(func, *args, **kwargs):
  2101. _, new_kwargs = normalize_function( # type: ignore[misc]
  2102. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2103. )
  2104. inp = new_kwargs.pop("input")
  2105. func(inp._values, **new_kwargs)
  2106. return inp
  2107. @register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
  2108. def frexp_Tensor(func, *args, **kwargs):
  2109. _, new_kwargs = normalize_function( # type: ignore[misc]
  2110. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2111. )
  2112. inp = new_kwargs.pop("input")
  2113. output_kwargs = extract_kwargs(inp)
  2114. mantissa, exponent = func(inp._values)
  2115. return NestedTensor(mantissa, **output_kwargs), NestedTensor(
  2116. exponent, **output_kwargs
  2117. )
  2118. @register_jagged_func(
  2119. torch.ops.aten.matmul_backward.default,
  2120. "grad: any, self: any, other: any, mask: any",
  2121. )
  2122. def matmul_backward_default(func, *args, **kwargs):
  2123. _, new_kwargs = normalize_function( # type: ignore[misc]
  2124. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2125. )
  2126. grad = new_kwargs.pop("grad")
  2127. inp = new_kwargs.pop("input")
  2128. other = new_kwargs.pop("other")
  2129. grad_input_mask = new_kwargs.pop("mask")
  2130. if grad is None:
  2131. return (None, None)
  2132. grad_self = None
  2133. if grad_input_mask[0]:
  2134. grad_self = torch.matmul(grad, other.transpose(-1, -2))
  2135. grad_other = None
  2136. if grad_input_mask[1]:
  2137. grad_other = torch.matmul(inp.transpose(-1, -2), grad)
  2138. return (grad_self, grad_other)
  2139. from torch._higher_order_ops.flex_attention import (
  2140. flex_attention as flex_attention_hop,
  2141. flex_attention_backward as flex_attention_backward_hop,
  2142. )
  2143. from torch.fx.graph_module import GraphModule
  2144. @flex_attention_hop.py_impl(NestedTensor) # type: ignore[misc]
  2145. def flex_njt(
  2146. query: torch.Tensor,
  2147. key: torch.Tensor,
  2148. value: torch.Tensor,
  2149. score_mod: Callable,
  2150. block_mask: Tuple,
  2151. scale: float,
  2152. kernel_options: Dict[str, Any],
  2153. score_mod_other_buffers: Tuple = (),
  2154. mask_mod_other_buffers: Tuple = (),
  2155. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  2156. assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4
  2157. # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense.
  2158. if any(
  2159. isinstance(buf, torch.Tensor) and buf.is_nested
  2160. for buf in score_mod_other_buffers + mask_mod_other_buffers
  2161. ):
  2162. raise RuntimeError(
  2163. "flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
  2164. "currently supported. Please file an issue if this is important to you."
  2165. )
  2166. # Always set them since 0 sized elements are not handled gracefully
  2167. kernel_options = {**kernel_options, "OUTPUT_MAX": True, "OUTPUT_LOGSUMEXP": True}
  2168. # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
  2169. output = flex_attention_hop(
  2170. query.values().unsqueeze(0),
  2171. key.values().unsqueeze(0),
  2172. value.values().unsqueeze(0),
  2173. score_mod=score_mod,
  2174. block_mask=block_mask,
  2175. scale=scale,
  2176. kernel_options=kernel_options,
  2177. score_mod_other_buffers=score_mod_other_buffers,
  2178. mask_mod_other_buffers=mask_mod_other_buffers,
  2179. )
  2180. # wrap outputs as NJT
  2181. output_njt = torch.nested.nested_tensor_from_jagged(
  2182. output[0].transpose(1, 2).squeeze(0),
  2183. query._offsets, # type: ignore[attr-defined]
  2184. query._lengths, # type: ignore[attr-defined]
  2185. min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
  2186. max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
  2187. ).transpose(1, 2)
  2188. logsumexp_njt = torch.nested.nested_tensor_from_jagged(
  2189. output[1].transpose(1, 2).squeeze(0),
  2190. query._offsets, # type: ignore[attr-defined]
  2191. query._lengths, # type: ignore[attr-defined]
  2192. min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
  2193. max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
  2194. ).transpose(1, 2)
  2195. max_scores_njt = torch.nested.nested_tensor_from_jagged(
  2196. output[2].transpose(1, 2).squeeze(0),
  2197. query._offsets, # type: ignore[attr-defined]
  2198. query._lengths, # type: ignore[attr-defined]
  2199. min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
  2200. max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
  2201. ).transpose(1, 2)
  2202. return (output_njt, logsumexp_njt, max_scores_njt)
  2203. @flex_attention_backward_hop.py_impl(NestedTensor) # type: ignore[misc]
  2204. def flex_njt_backward(
  2205. query: torch.Tensor,
  2206. key: torch.Tensor,
  2207. value: torch.Tensor,
  2208. out: torch.Tensor,
  2209. logsumexp: torch.Tensor,
  2210. grad_out: torch.Tensor,
  2211. grad_logsumexp: torch.Tensor,
  2212. fw_graph: Union[Callable, GraphModule],
  2213. joint_graph: GraphModule,
  2214. block_mask: Tuple,
  2215. scale: float,
  2216. kernel_options: Dict[str, Any],
  2217. score_mod_other_buffers: Tuple = (),
  2218. mask_mod_other_buffers: Tuple = (),
  2219. ) -> Tuple[
  2220. torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
  2221. ]:
  2222. output = flex_attention_backward_hop(
  2223. query.values().unsqueeze(0),
  2224. key.values().unsqueeze(0),
  2225. value.values().unsqueeze(0),
  2226. out=out.values().unsqueeze(0),
  2227. logsumexp=logsumexp.values().unsqueeze(0),
  2228. grad_out=grad_out.values().unsqueeze(0),
  2229. grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
  2230. fw_graph=fw_graph,
  2231. joint_graph=joint_graph,
  2232. block_mask=block_mask,
  2233. scale=scale,
  2234. kernel_options=kernel_options,
  2235. score_mod_other_buffers=score_mod_other_buffers,
  2236. mask_mod_other_buffers=mask_mod_other_buffers,
  2237. )
  2238. # wrap grads as NJTs
  2239. dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
  2240. njt_q_grad = torch.nested.nested_tensor_from_jagged(
  2241. dense_q_grad.transpose(1, 2).squeeze(0),
  2242. query._offsets, # type: ignore[attr-defined]
  2243. query._lengths, # type: ignore[attr-defined]
  2244. min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
  2245. max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
  2246. ).transpose(1, 2)
  2247. njt_k_grad = torch.nested.nested_tensor_from_jagged(
  2248. dense_k_grad.transpose(1, 2).squeeze(0),
  2249. key._offsets, # type: ignore[attr-defined]
  2250. key._lengths, # type: ignore[attr-defined]
  2251. min_seqlen=key._maybe_min_seqlen, # type: ignore[attr-defined]
  2252. max_seqlen=key._maybe_max_seqlen, # type: ignore[attr-defined]
  2253. ).transpose(1, 2)
  2254. njt_v_grad = torch.nested.nested_tensor_from_jagged(
  2255. dense_v_grad.transpose(1, 2).squeeze(0),
  2256. value._offsets, # type: ignore[attr-defined]
  2257. value._lengths, # type: ignore[attr-defined]
  2258. min_seqlen=value._maybe_min_seqlen, # type: ignore[attr-defined]
  2259. max_seqlen=value._maybe_max_seqlen, # type: ignore[attr-defined]
  2260. ).transpose(1, 2)
  2261. return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)
  2262. # Make the dummy available on the C++ side.
  2263. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
  2264. def _nested_get_jagged_dummy(func, *args, **kwargs):
  2265. from torch.nested._internal.nested_tensor import _nt_view_dummy
  2266. return _nt_view_dummy()
  2267. with torch.library._scoped_library("aten", "IMPL") as aten:
  2268. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
  2269. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
  2270. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")