| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424 |
- # mypy: allow-untyped-defs
- import functools
- import math
- import operator
- import sys
- from collections.abc import Callable
- from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
- from typing_extensions import TypeVarTuple, Unpack
- import sympy
- from sympy import S
- from sympy.core import sympify
- from sympy.core.expr import Expr
- from sympy.core.function import Application
- from sympy.core.logic import _torf, fuzzy_and, fuzzy_or
- from sympy.core.numbers import equal_valued
- from sympy.core.operations import LatticeOp, ShortCircuit
- from sympy.core.sorting import ordered
- from sympy.core.traversal import walk
- from sympy.printing.precedence import PRECEDENCE
- from sympy.utilities.iterables import sift
- from .numbers import int_oo
- if TYPE_CHECKING:
- from collections.abc import Iterable
- _T = TypeVar("_T", bound=SupportsFloat)
- _Ts = TypeVarTuple("_Ts")
- # Portions of this file are adapted from the Sympy codebase, which was
- # licensed as follows:
- #
- # Copyright (c) 2006-2023 SymPy Development Team
- #
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # a. Redistributions of source code must retain the above copyright notice,
- # this list of conditions and the following disclaimer.
- # b. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # c. Neither the name of SymPy nor the names of its contributors
- # may be used to endorse or promote products derived from this software
- # without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
- # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
- # DAMAGE.
- __all__ = [
- "FloorDiv",
- "ModularIndexing",
- "Where",
- "PythonMod",
- "Mod",
- "CleanDiv",
- "CeilToInt",
- "FloorToInt",
- "CeilDiv",
- "IntTrueDiv",
- "FloatTrueDiv",
- "LShift",
- "RShift",
- "IsNonOverlappingAndDenseIndicator",
- "TruncToFloat",
- "TruncToInt",
- "RoundToInt",
- "RoundDecimal",
- "ToFloat",
- "FloatPow",
- "PowByNatural",
- "Identity",
- ]
- def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
- # No need to check that two args are not the same, since expr is pr-optimized but we do it anyway.
- return (
- expr.is_Add
- and len(expr._args) == 2
- and expr._args[0].is_symbol
- and expr._args[1].is_symbol
- and expr._args[0] is not expr._args[1]
- )
- def _keep_float(
- f: Callable[[Unpack[_Ts]], _T],
- ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
- @functools.wraps(f)
- def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
- r: Union[_T, sympy.Float] = f(*args)
- if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
- r, sympy.Float
- ):
- r = sympy.Float(float(r))
- return r
- return inner
- def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
- if None in (x, y):
- return None
- return x == y
- def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
- """
- Fast path for sympy.gcd, using a simple factoring strategy.
- We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
- where n is the greatest common integer factor and e is the largest
- syntactic common factor (i.e., common sub-expression) in p and q.
- Then the gcd returned is n*e, cancelling which we would be left with
- p1 + p2 and q0.
- Note that further factoring of p1 + p2 and q0 might be possible with
- sympy.factor (which uses domain-specific theories). E.g., we are unable
- to find that x*y + x + y + 1 is divisible by x + 1. More generally,
- when q is of the form q1 + q2 (instead of being already factored) it
- might be necessary to fall back on sympy.gcd.
- """
- def integer_coefficient(x: sympy.Basic) -> int:
- integer_coefficients: list[int] = [
- abs(int(arg))
- for arg in sympy.Mul.make_args(x)
- if isinstance(arg, (int, sympy.Integer))
- ]
- return math.prod(integer_coefficients)
- def integer_factor(expr: sympy.Basic) -> int:
- integer_factors: Iterable[int] = map(
- integer_coefficient, sympy.Add.make_args(expr)
- )
- return functools.reduce(math.gcd, integer_factors)
- gcd: int = math.gcd(integer_factor(p), integer_factor(q))
- p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12
- base_splits: list[tuple[sympy.Basic, ...]] = list(
- map(sympy.Mul.make_args, sympy.Add.make_args(p))
- )
- divisor_split: tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
- for x in divisor_split:
- if all(x in base_split for base_split in base_splits):
- gcd = gcd * x # type: ignore[operator] # remove in py3.12
- return gcd # type: ignore[return-value] # remove in py3.12
- # It would be nice to have assertions on whether or not inputs is_integer
- # However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
- # sometimes inconsistently reports floats an integers.
- #
- # What we can assume from sympy is that if something is an int, it
- # definitely is is_integer, but if it is a float it may or may not
- # be is_integer. So we are unable to do strong asserts that things
- # are NOT integers.
- # TODO: In Triton, // rounds to zero, but in Python, it is floor division.
- # When we can prove both arguments are non-negative, we should just have a
- # GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
- # and then PythonFloorDiv and CIntDiv which have the appropriate rounding
- # semantics.
- #
- # Right now, FloorDiv de facto changes behavior if arguments are negative or
- # not, this can potentially cause correctness issues.
- class FloorDiv(sympy.Function):
- """
- We maintain this so that:
- 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
- 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
- NB: This is Python-style floor division, round to -Inf
- """
- nargs: tuple[int, ...] = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer: bool = True
- @property
- def base(self) -> sympy.Basic:
- return self.args[0]
- @property
- def divisor(self) -> sympy.Basic:
- return self.args[1]
- def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
- base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5)
- divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5)
- return f"({base}//{divisor})"
- # Automatic evaluation.
- # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
- @classmethod
- def eval(
- cls, base: sympy.Integer, divisor: sympy.Integer
- ) -> Union[sympy.Basic, None]:
- # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
- # Assert triggered by inequality solver
- # assert base.is_integer, base
- # assert divisor.is_integer, divisor
- # We don't provide the same error message as in Python because SymPy
- # makes it difficult to check the types.
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
- int_oo,
- -int_oo,
- sympy.oo,
- -sympy.oo,
- ):
- return sympy.nan
- if base is sympy.nan or divisor is sympy.nan:
- return sympy.nan
- if base.is_zero:
- return sympy.S.Zero
- if base.is_integer and equal_valued(divisor, 1):
- return base
- if base.is_integer and equal_valued(divisor, -1):
- return sympy.Mul(base, -1)
- if (
- isinstance(base, sympy.Number)
- and isinstance(divisor, sympy.Number)
- and (
- base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
- or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
- )
- ):
- r = float(base) / float(divisor)
- if r == math.inf:
- return int_oo
- elif r == -math.inf:
- return -int_oo
- elif math.isnan(r):
- return sympy.nan
- else:
- return sympy.Integer(math.floor(r))
- if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
- return sympy.Integer(int(base) // int(divisor))
- if isinstance(base, FloorDiv):
- return FloorDiv(base.args[0], base.args[1] * divisor)
- # Expands (x + y) // b into x // b + y // b.
- # This only works if floor is an identity, i.e. x / b is an integer.
- if isinstance(divisor, sympy.Integer):
- quotients = 0
- terms = []
- for term in sympy.Add.make_args(base):
- quotient = term / divisor
- if quotient.is_integer:
- terms.append(term)
- quotients += quotient
- if len(terms) != 0:
- # Passing evaluate = False since expression will be optimized during the subtraction post its construction.
- return (
- FloorDiv(base - sympy.Add(*terms, evaluate=False), divisor)
- + quotients
- )
- try:
- gcd = simple_floordiv_gcd(base, divisor)
- if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
- gcd = sympy.gcd(base, divisor)
- if not equal_valued(gcd, 1):
- return FloorDiv(
- sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
- )
- except sympy.PolynomialError:
- pass # https://github.com/pytorch/pytorch/issues/108276
- return None
- def _ccode(self, printer):
- base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5)
- divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5)
- return f"floor({base}/{divisor})"
- class ModularIndexing(sympy.Function):
- """
- ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
- """
- nargs: tuple[int, ...] = (3,)
- is_integer: bool = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(
- cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
- ) -> Optional[sympy.Basic]:
- if base == 0 or modulus == 1:
- return sympy.S.Zero
- if (
- isinstance(base, sympy.Integer)
- and isinstance(divisor, sympy.Integer)
- and isinstance(modulus, sympy.Integer)
- ):
- return (base // divisor) % modulus
- try:
- if divisor != 1:
- gcd = sympy.gcd(base, divisor)
- if gcd != 1:
- return ModularIndexing(
- sympy.simplify(base / gcd),
- sympy.simplify(divisor / gcd),
- modulus,
- )
- except sympy.PolynomialError:
- pass # https://github.com/pytorch/pytorch/issues/108276
- if isinstance(base, sympy.Add):
- new_terms: list[sympy.Integer] = []
- all_positive: bool = True
- for term in base.args:
- if sympy.gcd(term, modulus * divisor) != modulus * divisor:
- if (isinstance(term, sympy.Integer) and term < 0) or (
- isinstance(term, sympy.Mul)
- and isinstance(term.args[0], sympy.Integer)
- and term.args[0] < 0
- ):
- # workaround for https://github.com/triton-lang/triton/issues/619,
- # if there are negative terms, // produces wrong result
- # TODO if https://github.com/triton-lang/triton/issues/619 is fixed
- # this optimization would become valid
- all_positive = False
- break
- else:
- new_terms.append(term)
- if len(new_terms) != len(base.args) and all_positive:
- return ModularIndexing(sum(new_terms), divisor, modulus)
- if isinstance(base, FloorDiv):
- return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
- return None
- def _eval_is_nonnegative(self) -> Optional[bool]:
- p, q = self.args[:2]
- return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
- class Where(sympy.Function):
- """
- Good ol' ternary operator
- """
- nargs: tuple[int, ...] = (3,)
- precedence: int = 35 # lower precedence than add
- def _eval_is_integer(self) -> Optional[bool]:
- return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
- def _eval_is_nonnegative(self) -> Optional[bool]:
- return (
- True
- if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
- else None
- )
- def _eval_is_positive(self) -> Optional[bool]:
- return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
- @classmethod
- def eval(
- cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
- ) -> Optional[sympy.Basic]:
- if c == sympy.true:
- return p
- elif c == sympy.false:
- return q
- return None
- # Python-style modulus: take sign from RHS
- class PythonMod(sympy.Function):
- nargs: tuple[int, ...] = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer: bool = True
- @classmethod
- def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
- # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
- # Triggered by sympy.solvers.inequalities.reduce_inequalities
- # assert p.is_integer, p
- # assert q.is_integer, q
- if q.is_zero:
- raise ZeroDivisionError("Modulo by zero")
- # Three cases:
- # 1. p == 0
- # 2. p is either q or -q
- # 3. p is integer and q == 1
- if p is S.Zero or p in (q, -q) or q == 1:
- return S.Zero
- # Evaluate if they are both literals.
- if q.is_Number and p.is_Number:
- return p % q
- # If q == 2, it's a matter of whether p is odd or even.
- if q.is_Number and q == 2:
- if p.is_even:
- return S.Zero
- if p.is_odd:
- return S.One
- # If p is a multiple of q.
- r = p / q
- if r.is_integer:
- return S.Zero
- # If p < q and its ratio is positive, then:
- # - floor(p / q) = 0
- # - p % q = p - floor(p / q) * q = p
- less = p < q
- if less.is_Boolean and bool(less) and r.is_positive:
- return p
- if sympy.Mod(p, q) == 0:
- return S.Zero
- return None
- # NB: args[1] for PythonMod
- def _eval_is_nonnegative(self) -> Optional[bool]:
- return True if self.args[1].is_positive else None # type: ignore[attr-defined]
- def _eval_is_nonpositive(self) -> Optional[bool]:
- return True if self.args[1].is_negative else None # type: ignore[attr-defined]
- def _ccode(self, printer):
- p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
- q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
- abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
- return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"
- # Generic modulus: only defined on non-negative arguments
- class Mod(sympy.Function):
- nargs = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer = True
- is_nonnegative = True
- @classmethod
- def eval(cls, p, q):
- # This was adapted from: sympy/core/mod.py
- # Triggered by
- # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
- # assert p.is_integer, p
- # assert q.is_integer, q
- if q.is_zero:
- raise ZeroDivisionError("Modulo by zero")
- # Three cases:
- # 1. p == 0
- # 2. p is either q or -q
- # 3. p is integer and q == 1
- if p is S.Zero or p in (q, -q) or q == 1:
- return S.Zero
- # Evaluate if they are both literals.
- if q.is_Number and p.is_Number:
- assert p >= 0, p
- assert q >= 1, q
- return p % q
- # If q == 2, it's a matter of whether p is odd or even.
- if q.is_Number and q == 2:
- if p.is_even:
- return S.Zero
- if p.is_odd:
- return S.One
- # If p is a multiple of q.
- r = p / q
- if r.is_integer:
- return S.Zero
- # If p < q and its ratio is positive, then:
- # - floor(p / q) = 0
- # - p % q = p - floor(p / q) * q = p
- less = p < q
- if less.is_Boolean and bool(less) and r.is_positive:
- return p
- class CleanDiv(FloorDiv):
- """
- Div where we can assume no rounding.
- This is to enable future optimizations.
- """
- # Don't use sympy ceiling/floor as they will attempt simplifications involving
- # frac
- class CeilToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, -int_oo):
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.ceil(float(number)))
- def _ccode(self, printer):
- number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
- return f"ceil({number})"
- class FloorToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, int_oo):
- return -int_oo
- if isinstance(number, sympy.Integer):
- return number
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.floor(float(number)))
- class CeilDiv(sympy.Function):
- """
- Div used in indexing that rounds up.
- """
- is_integer = True
- def __new__(cls, base, divisor):
- base = sympy.sympify(base)
- divisor = sympy.sympify(divisor)
- if sympy.gcd(base, divisor) == divisor:
- return CleanDiv(base, divisor)
- else:
- return FloorDiv(base + (divisor - 1), divisor)
- class LShift(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, base, shift):
- if shift < 0:
- raise ValueError("negative shift count")
- return base * 2**shift
- class RShift(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, base, shift):
- if shift < 0:
- raise ValueError("negative shift count")
- return FloorDiv(base, 2**shift)
- class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
- def __new__(cls, *original_args, **assumptions):
- from sympy.core.parameters import global_parameters
- evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
- args = (sympify(arg) for arg in original_args)
- # See the comment in _satisfy_unique_summations_symbols.
- unique_summations_symbols = (
- None
- if not evaluate
- else cls._satisfy_unique_summations_symbols(original_args)
- )
- if evaluate:
- try:
- # first standard filter, for cls.zero and cls.identity
- # also reshape Max(a, Max(b, c)) to Max(a, b, c)
- args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment]
- except ShortCircuit:
- return cls.zero # type: ignore[attr-defined]
- # No need to run _collapse_arguments and _find_localzeros, see the comment
- # in _satisfy_unique_summations_symbols.
- if unique_summations_symbols is None:
- # remove redundant args that are easily identified
- args = cls._collapse_arguments(args, **assumptions)
- # find local zeros
- args = cls._find_localzeros(args, **assumptions)
- args = frozenset(args)
- if not args:
- return cls.identity # type: ignore[attr-defined]
- if len(args) == 1:
- return list(args).pop()
- # base creation
- obj = Expr.__new__(cls, *ordered(args), **assumptions)
- obj._argset = args
- obj.unique_summations_symbols = unique_summations_symbols
- return obj
- @classmethod
- def _satisfy_unique_summations_symbols(
- cls, args
- ) -> Optional[set[sympy.core.symbol.Symbol]]:
- """
- One common case in some models is building expressions of the form
- max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...).
- For such expressions, we call the Max constructor X times (once for each nested
- max) and the expression gets flattened.
- An expensive cost in constructing those expressions is running _collapse_arguments
- and _find_localzeros. However, those two optimizations are unnecessary when the args
- to max are all of the form a+b, c+d, ..etc where each term uses a unique set of symbols.
- This function is used to detect such properties of the expressions we are building
- and if so inform that we do not need to run those optimizations. To detect those,
- we store a property in the expression that tells that this expression is a min/max
- operation over terms that use unique symbols "unique_summations_symbols". This property
- also memoize the set of symbols used in all the terms to make it faster to detect this
- property inductively.
- When we apply max to add a new term, all we need to do is check if the new term uses
- unique symbols (with respect to existing terms and itself).
- Example:
- t = Max(a+b, c+d) ==> satisfies the property
- Max(t, h+j) ==> h,j not in [a,b,c,d] => satisfy the property.
- The function returns None if the new expression does not satisfy the unique_summations_symbols
- property. Otherwise, it returns a new set of unique symbols.
- """
- if len(args) != 2:
- return None
- (lhs, rhs) = (
- (args[1], args[0])
- if isinstance(args[1], MinMaxBase)
- else (args[0], args[1])
- )
- if not _is_symbols_binary_summation(rhs):
- return None
- # base case max(a+b, c+d) ==> satisfies the property if a+b and c+d use unique symbols.
- if _is_symbols_binary_summation(lhs):
- return cls._unique_symbols(args)
- # inductive case max(t, h+j) ==> satisfies the property if h, j not in t.unique_summations_symbols
- if isinstance(lhs, MinMaxBase):
- lhs_unique_summations_symbols = getattr(
- lhs, "unique_summations_symbols", None
- )
- if lhs_unique_summations_symbols is not None:
- return cls._unique_symbols([rhs], lhs_unique_summations_symbols)
- return None
- @classmethod
- def _unique_symbols(
- cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
- ) -> Optional[set[sympy.core.symbol.Symbol]]:
- """
- Return seen_symbols if all atoms in all args are all unique symbols,
- else returns None. initial_set can be used to represent initial value for seen_symbols
- """
- seen_symbols = set() if initial_set is None else initial_set
- for arg in args:
- for element in arg.atoms():
- if not isinstance(element, sympy.core.symbol.Symbol):
- return None
- elif element in seen_symbols:
- return None
- else:
- seen_symbols.add(element)
- return seen_symbols
- @classmethod
- def _collapse_arguments(cls, args, **assumptions):
- """Remove redundant args.
- Examples
- ========
- >>> from sympy import Min, Max
- >>> from sympy.abc import a, b, c, d, e
- Any arg in parent that appears in any
- parent-like function in any of the flat args
- of parent can be removed from that sub-arg:
- >>> Min(a, Max(b, Min(a, c, d)))
- Min(a, Max(b, Min(c, d)))
- If the arg of parent appears in an opposite-than parent
- function in any of the flat args of parent that function
- can be replaced with the arg:
- >>> Min(a, Max(b, Min(c, d, Max(a, e))))
- Min(a, Max(b, Min(a, c, d)))
- """
- if not args:
- return args
- args = list(ordered(args))
- if cls is Min:
- other = Max
- else:
- other = Min # type: ignore[assignment]
- # find global comparable max of Max and min of Min if a new
- # value is being introduced in these args at position 0 of
- # the ordered args
- if args[0].is_number:
- sifted = mins, maxs = [], [] # type: ignore[var-annotated]
- for i in args:
- for v in walk(i, Min, Max):
- if v.args[0].is_comparable:
- sifted[isinstance(v, Max)].append(v)
- small = Min.identity
- for i in mins:
- v = i.args[0]
- if v.is_number and (v < small) == True: # noqa: E712
- small = v
- big = Max.identity
- for i in maxs:
- v = i.args[0]
- if v.is_number and (v > big) == True: # noqa: E712
- big = v
- # at the point when this function is called from __new__,
- # there may be more than one numeric arg present since
- # local zeros have not been handled yet, so look through
- # more than the first arg
- if cls is Min:
- for arg in args:
- if not arg.is_number:
- break
- if (arg < small) == True: # noqa: E712
- small = arg
- elif cls == Max:
- for arg in args:
- if not arg.is_number:
- break
- if (arg > big) == True: # noqa: E712
- big = arg
- T = None
- if cls is Min:
- if small != Min.identity:
- other = Max
- T = small
- elif big != Max.identity:
- other = Min # type: ignore[assignment]
- T = big
- if T is not None:
- # remove numerical redundancy
- for i in range(len(args)):
- a = args[i]
- if isinstance(a, other):
- a0 = a.args[0]
- if ( # noqa: E712
- (a0 > T) if other == Max else (a0 < T) # noqa: E712
- ) == True: # noqa: E712
- args[i] = cls.identity # type: ignore[attr-defined]
- # remove redundant symbolic args
- def do(ai, a):
- if not isinstance(ai, (Min, Max)):
- return ai
- cond = a in ai.args
- if not cond:
- return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
- if isinstance(ai, cls):
- return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
- return a
- for i, a in enumerate(args):
- args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]
- # factor out common elements as for
- # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
- # and vice versa when swapping Min/Max -- do this only for the
- # easy case where all functions contain something in common;
- # trying to find some optimal subset of args to modify takes
- # too long
- def factor_minmax(args):
- is_other = lambda arg: isinstance(arg, other) # noqa: E731
- other_args, remaining_args = sift(args, is_other, binary=True)
- if not other_args:
- return args
- # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
- arg_sets = [set(arg.args) for arg in other_args]
- common = set.intersection(*arg_sets)
- if not common:
- return args
- new_other_args = list(common)
- arg_sets_diff = [arg_set - common for arg_set in arg_sets]
- # If any set is empty after removing common then all can be
- # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
- if all(arg_sets_diff):
- other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
- new_other_args.append(cls(*other_args_diff, evaluate=False))
- other_args_factored = other(*new_other_args, evaluate=False)
- return remaining_args + [other_args_factored]
- if len(args) > 1:
- args = factor_minmax(args)
- return args
- @classmethod
- def _new_args_filter(cls, arg_sequence):
- """
- Generator filtering args.
- first standard filter, for cls.zero and cls.identity.
- Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
- and check arguments for comparability
- """
- for arg in arg_sequence:
- # pre-filter, checking comparability of arguments
- if (
- not isinstance(arg, Expr)
- or arg.is_extended_real is False
- or (arg.is_number and not arg.is_comparable)
- ):
- raise ValueError(f"The argument '{arg}' is not comparable.")
- if arg == cls.zero: # type: ignore[attr-defined]
- raise ShortCircuit(arg)
- elif arg == cls.identity: # type: ignore[attr-defined]
- continue
- elif arg.func == cls:
- yield from arg.args
- else:
- yield arg
- @classmethod
- def _find_localzeros(cls, values, **options):
- """
- Sequentially allocate values to localzeros.
- When a value is identified as being more extreme than another member it
- replaces that member; if this is never true, then the value is simply
- appended to the localzeros.
- Unlike the sympy implementation, we only look for zero and one, we don't
- do generic is connected test pairwise which is slow
- """
- # First, collapse all numeric arguments
- other_values = set()
- num_value = None
- for arg in values:
- if arg.is_Number:
- if num_value is None:
- num_value = arg
- else:
- if cls is Max:
- num_value = max(num_value, arg)
- elif cls is Min:
- num_value = min(num_value, arg)
- else:
- raise AssertionError(f"impossible {cls}")
- else:
- other_values.add(arg)
- # Special cases when there is only one symbolic value
- if num_value is None:
- return other_values
- if len(other_values) == 0:
- return {num_value}
- if len(other_values) == 1:
- other_value = next(iter(other_values))
- if num_value in (0.0, 0) and other_value.is_nonnegative:
- return other_values if cls is Max else {num_value}
- if num_value == 1 and other_value.is_positive:
- return other_values if cls is Max else {num_value}
- other_values.add(num_value)
- return other_values
- _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
- _eval_is_antihermitian = lambda s: _torf( # noqa: E731
- i.is_antihermitian
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_commutative = lambda s: _torf( # noqa: E731
- i.is_commutative
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
- _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
- _eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731
- _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) # noqa: E731
- _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) # noqa: E731
- _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) # noqa: E731
- _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) # noqa: E731
- _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) # noqa: E731
- _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) # noqa: E731
- _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
- _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
- _eval_is_nonnegative = lambda s: _torf( # noqa: E731
- i.is_nonnegative
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_nonpositive = lambda s: _torf( # noqa: E731
- i.is_nonpositive
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
- _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
- _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731
- _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) # noqa: E731
- _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) # noqa: E731
- _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
- _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
- _eval_is_extended_real = lambda s: _torf( # noqa: E731
- i.is_extended_real
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_transcendental = lambda s: _torf( # noqa: E731
- i.is_transcendental
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
- class Max(MinMaxBase, Application): # type: ignore[misc]
- r"""
- Return, if possible, the maximum value of the list.
- """
- zero = S.Infinity
- identity = S.NegativeInfinity
- def _eval_is_positive(self): # type:ignore[override]
- return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined]
- def _eval_is_nonnegative(self): # type:ignore[override]
- return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
- def _eval_is_negative(self): # type:ignore[override]
- return fuzzy_and(a.is_negative for a in self.args)
- class Min(MinMaxBase, Application): # type: ignore[misc]
- """
- Return, if possible, the minimum value of the list.
- """
- zero = S.NegativeInfinity
- identity = S.Infinity
- def _eval_is_positive(self): # type:ignore[override]
- return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined]
- def _eval_is_nonnegative(self): # type:ignore[override]
- return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
- def _eval_is_negative(self): # type:ignore[override]
- return fuzzy_or(a.is_negative for a in self.args)
- def safe_pow(base, exp):
- sign = 1
- if base < 0:
- base = -base
- sign = 1 if exp % 2 == 0 else -1
- return sign * _safe_pow(base, exp)
- # Prevent people from overflowing pow
- def _safe_pow(base, exponent):
- if exponent < 0:
- raise ValueError("Exponent must be non-negative.")
- if exponent == 0:
- return 1
- half_exp = safe_pow(base, exponent // 2)
- if half_exp is int_oo:
- return int_oo
- # TODO: microoptimization is to avoid overflowing into arbitrary precision
- # and detect overflow prior to doing operations
- result = half_exp * half_exp
- if result > sys.maxsize:
- return int_oo
- if exponent % 2 == 1:
- result *= base
- if result > sys.maxsize:
- return int_oo
- return result
- class PowByNatural(sympy.Function):
- is_integer = True
- precedence: int = 50 # precedence of mul
- @classmethod
- def eval(cls, base, exp):
- if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
- r = safe_pow(base, exp)
- if r in (-int_oo, int_oo):
- return r
- return sympy.Integer(r)
- if isinstance(exp, sympy.Integer):
- # Rely on regular sympy Pow for this (note that iterated
- # multiplication turns into a Pow anyway, you can't escape!!)
- return sympy.Pow(base, exp)
- if exp in (int_oo, sympy.oo):
- if base.is_nonnegative:
- return int_oo
- elif base.is_negative:
- return sympy.zoo # this is apparently what (-2)**sympy.oo does
- # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
- # is a natural number if we do
- # base is assumed to be nonnegative, thereby prevent complex numbers from
- # occurring
- class FloatPow(sympy.Function):
- is_real = True
- precedence: int = 60 # precedence of pow
- @classmethod
- def eval(cls, base, exp):
- # NB: These test sympy.Number, not sympy.Float, because:
- # - Sometimes we may have sympy.oo or int_oo, and that's not a Float
- # (but coerces to math.Inf)
- # - Sometimes Float(0.0) will unpredictably decay to Integer(0),
- # but we should still accept it in floatey contexts
- if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
- return sympy.Float(float(base) ** float(exp))
- # NB: do not do any nontrivial reasoning
- # Overloaded to be compatible with regular Python.
- # https://github.com/pytorch/pytorch/issues/90900
- #
- # In particular, sympy division is willing to simplify x/x == 1
- # where 1 is an integer, but this must be a float if x was float.
- class FloatTrueDiv(sympy.Function):
- is_real = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(cls, base, divisor):
- # assert base.is_integer is not True, base
- # assert divisor.is_integer is not True, divisor
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
- return sympy.Float(float(base) / float(divisor))
- # Overloaded to be compatible with regular Python. We distinguish this from
- # FloatTrueDiv, because the code generation has to be different for this case:
- # Python has a fancy algorithm for integer true division that isn't just
- # "promote both arguments to float and use float division", so you need to
- # codegen it differently. While technically you can work it out from the
- # types of the input, this is often inconvenient to do in Inductor codegen,
- # so just have a different operator
- # NB: Right now, Inductor codegen doesn't implement this correctly lol
- class IntTrueDiv(sympy.Function):
- is_real = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(cls, base, divisor):
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if (
- isinstance(base, sympy.Number)
- and isinstance(divisor, sympy.Number)
- and (
- base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
- or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
- )
- ):
- # Don't have to worry about precision here, you're getting zero or
- # inf from the division
- return sympy.Float(float(base) / float(divisor))
- if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
- return sympy.Float(int(base) / int(divisor))
- def _ccode(self, printer):
- base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
- divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
- return f"((int){base}/(int){divisor})"
- # TODO: As an indicator, this != 0 implies == 1 (and vice versa).
- # Because we do not have the ability to guard on the stride permutation
- # at the moment, it is hard to make further inferences when this is true,
- # as although we know the tensor is contiguous in *some* layout, we don't
- # know which one (however, you could, for example, make the inference that
- # reshaping this to a 1D tensor can be guard-free.)
- class IsNonOverlappingAndDenseIndicator(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, *args):
- assert len(args) % 2 == 0
- dim = len(args) // 2
- sizes = args[0:dim]
- strides = args[dim:]
- # sym_node imported in torch.__init__. Local import to avoid an import cycle
- from torch.fx.experimental.symbolic_shapes import (
- eval_is_non_overlapping_and_dense,
- )
- if all(isinstance(a, sympy.Integer) for a in args):
- return eval_is_non_overlapping_and_dense(
- [int(a) for a in sizes], [int(a) for a in strides]
- )
- if dim == 1:
- # Manually implement the rank one short circuit
- if strides[0].is_Number and strides[0] == 1:
- return 1
- if sizes[0].is_Number and sizes[0] < 2:
- return 1
- # return 0 case covered by case above
- # TODO: Inability to access size-obliviousness sucks: if we have a
- # size oblivious test on a size-like unbacked SymInt, we could
- # confidently return zero when we have a size-like u0 stride
- # and a size-like u1 size. Maybe a fancy ValueRanges analysis for
- # this function could help figure this out.
- if all(isinstance(a, sympy.Integer) for a in strides):
- assert dim != 0
- # When all strides are integral, we can sort, and the size for the
- # largest stride doesn't matter and can be arbitrarily symbolic
- s_sizes, s_strides = zip(
- *sorted(zip(sizes, strides, strict=False), key=operator.itemgetter(1)),
- strict=False,
- )
- # Put something arbitrary in the max size spot, it'll be ignored
- if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
- s_sizes = s_sizes[:-1] + (42,)
- # We can reuse the regular eval, because it is invariant to
- # permutation of dimensions
- return eval_is_non_overlapping_and_dense(
- [int(a) for a in s_sizes], [int(a) for a in s_strides]
- )
- return None
- # NB: this is inconsistent with math.trunc in Python
- class TruncToFloat(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if isinstance(number, sympy.Number):
- # NB: It is safe to use truncation to integer, which is what
- # math.trunc does, as Python integers are arbitrary precision and
- # so we are guaranteed not to lose precision when we do this
- return sympy.Float(math.trunc(float(number)))
- class TruncToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, -int_oo):
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.trunc(float(number)))
- # This is float -> int
- class RoundToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number is sympy.oo:
- return int_oo
- if number is -sympy.oo:
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(round(float(number), 0))
- # To get float -> int, Python style round semantics.
- #
- # x = PyFloat_AsDouble(self);
- # if (o_ndigits == Py_None) {
- # /* single-argument round or with None ndigits:
- # * round to nearest integer */
- # rounded = round(x);
- # if (fabs(x-rounded) == 0.5)
- # /* halfway case: round to even */
- # rounded = 2.0*round(x/2.0);
- # return PyLong_FromDouble(rounded);
- # }
- # NB: Like Round, this only ever returns floats. ndigits cannot be None
- class RoundDecimal(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number, ndigits):
- # assert number.is_integer is not True, number
- if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
- return sympy.Float(round(float(number), int(ndigits)))
- class ToFloat(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number):
- if number in [sympy.oo, -sympy.oo]:
- return number
- if isinstance(number, sympy.Integer):
- return sympy.Float(int(number))
- if number is int_oo:
- return sympy.oo
- if number is -int_oo:
- return -sympy.oo
- class Identity(sympy.Function):
- """
- Prevents expansion and other optimizations
- """
- precedence = 10
- def __repr__(self): # type: ignore[override]
- return f"Identity({self.args[0]})"
- def _eval_is_real(self):
- return self.args[0].is_real
- def _eval_is_integer(self):
- return self.args[0].is_integer # type: ignore[attr-defined]
- def _eval_expand_identity(self, **hints):
- # Removes the identity op.
- return self.args[0]
- def __int__(self) -> int:
- return int(self.args[0])
- def __float__(self) -> float:
- return float(self.args[0])
- def make_opaque_unary_fn(name):
- class OpaqueUnaryFn(sympy.Function):
- """
- Unlike the builtin sympy functions on real numbers like sympy.sqrt,
- these equivalents do not do any nontrivial reasoning besides
- constant propagation. This helps avoid performing transformations
- that are valid for real numbers but are invalid for floating point;
- in particular, while we are willing to make optimizations that change
- numerics for Tensor compute, we are NOT willing to make optimizations
- that change numerics for size compute.
- """
- _torch_handler_name = name
- _torch_unpickler = make_opaque_unary_fn
- @classmethod
- def eval(cls, a):
- if isinstance(a, (sympy.Integer, sympy.Float)):
- # Python converts to float64 before computing, c.f.
- # >>> math.sin(2**53+1)
- # -0.848925964814655
- # >>> math.sin(float(2**53+1))
- # -0.848925964814655
- try:
- return sympy.Float(getattr(math, name)(float(a)))
- # Just use sympy semantics for infinity/overflow, you might get some
- # weird objects but ask silly questions, get silly answers
- except OverflowError:
- return getattr(sympy, name)(a)
- elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
- if a is int_oo:
- a = sympy.oo
- if a is -int_oo:
- a = -sympy.oo
- if name == "log2":
- return sympy.log(a, 2)
- return getattr(sympy, name)(a)
- return None
- nm = "OpaqueUnaryFn_" + name
- OpaqueUnaryFn.__name__ = nm
- OpaqueUnaryFn.__qualname__ = nm
- return OpaqueUnaryFn
- # Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
- OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
- OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
- OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
- OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
- OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
- OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
- OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
- OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
- OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
- OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
- OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
- OpaqueUnaryFn_log = make_opaque_unary_fn("log")
- OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
- OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
- def make_opaque_bitwise_fn(name, real_op_name):
- if name == "bitwise_and":
- prec = PRECEDENCE["BitwiseAnd"]
- elif name == "bitwise_or":
- prec = PRECEDENCE["BitwiseOr"]
- else:
- raise AssertionError(f"unrecognized {name}")
- class BitwiseFn(sympy.Function):
- _torch_handler_name = name
- precedence: int = prec
- _torch_unpickler = functools.partial(
- make_opaque_bitwise_fn, real_op_name=real_op_name
- )
- @classmethod
- def eval(cls, a, b):
- if a.is_Boolean and b.is_Boolean:
- return getattr(operator, real_op_name)(a, b)
- if a.is_Boolean:
- a = sympy.Integer(1 if a else 0)
- if b.is_Boolean:
- b = sympy.Integer(1 if b else 0)
- if isinstance(a, (sympy.Integer, int)) and isinstance(
- b, (sympy.Integer, int)
- ):
- return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
- return None
- nm = "BitwiseFn_" + name
- BitwiseFn.__name__ = nm
- BitwiseFn.__qualname__ = nm
- return BitwiseFn
- BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
- BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
|