_parse_latex_antlr.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. # Ported from latex2sympy by @augustt198
  2. # https://github.com/augustt198/latex2sympy
  3. # See license in LICENSE.txt
  4. from importlib.metadata import version
  5. import sympy
  6. from sympy.external import import_module
  7. from sympy.printing.str import StrPrinter
  8. from sympy.physics.quantum.state import Bra, Ket
  9. from .errors import LaTeXParsingError
  10. LaTeXParser = LaTeXLexer = MathErrorListener = None
  11. try:
  12. LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser',
  13. import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser
  14. LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer',
  15. import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer
  16. except Exception:
  17. pass
  18. ErrorListener = import_module('antlr4.error.ErrorListener',
  19. warn_not_installed=True,
  20. import_kwargs={'fromlist': ['ErrorListener']}
  21. )
  22. if ErrorListener:
  23. class MathErrorListener(ErrorListener.ErrorListener): # type:ignore # noqa:F811
  24. def __init__(self, src):
  25. super(ErrorListener.ErrorListener, self).__init__()
  26. self.src = src
  27. def syntaxError(self, recog, symbol, line, col, msg, e):
  28. fmt = "%s\n%s\n%s"
  29. marker = "~" * col + "^"
  30. if msg.startswith("missing"):
  31. err = fmt % (msg, self.src, marker)
  32. elif msg.startswith("no viable"):
  33. err = fmt % ("I expected something else here", self.src, marker)
  34. elif msg.startswith("mismatched"):
  35. names = LaTeXParser.literalNames
  36. expected = [
  37. names[i] for i in e.getExpectedTokens() if i < len(names)
  38. ]
  39. if len(expected) < 10:
  40. expected = " ".join(expected)
  41. err = (fmt % ("I expected one of these: " + expected, self.src,
  42. marker))
  43. else:
  44. err = (fmt % ("I expected something else here", self.src,
  45. marker))
  46. else:
  47. err = fmt % ("I don't understand this", self.src, marker)
  48. raise LaTeXParsingError(err)
  49. def parse_latex(sympy, strict=False):
  50. antlr4 = import_module('antlr4')
  51. if None in [antlr4, MathErrorListener] or \
  52. not version('antlr4-python3-runtime').startswith('4.11'):
  53. raise ImportError("LaTeX parsing requires the antlr4 Python package,"
  54. " provided by pip (antlr4-python3-runtime) or"
  55. " conda (antlr-python-runtime), version 4.11")
  56. sympy = sympy.strip()
  57. matherror = MathErrorListener(sympy)
  58. stream = antlr4.InputStream(sympy)
  59. lex = LaTeXLexer(stream)
  60. lex.removeErrorListeners()
  61. lex.addErrorListener(matherror)
  62. tokens = antlr4.CommonTokenStream(lex)
  63. parser = LaTeXParser(tokens)
  64. # remove default console error listener
  65. parser.removeErrorListeners()
  66. parser.addErrorListener(matherror)
  67. relation = parser.math().relation()
  68. if strict and (relation.start.start != 0 or relation.stop.stop != len(sympy) - 1):
  69. raise LaTeXParsingError("Invalid LaTeX")
  70. expr = convert_relation(relation)
  71. return expr
  72. def convert_relation(rel):
  73. if rel.expr():
  74. return convert_expr(rel.expr())
  75. lh = convert_relation(rel.relation(0))
  76. rh = convert_relation(rel.relation(1))
  77. if rel.LT():
  78. return sympy.StrictLessThan(lh, rh)
  79. elif rel.LTE():
  80. return sympy.LessThan(lh, rh)
  81. elif rel.GT():
  82. return sympy.StrictGreaterThan(lh, rh)
  83. elif rel.GTE():
  84. return sympy.GreaterThan(lh, rh)
  85. elif rel.EQUAL():
  86. return sympy.Eq(lh, rh)
  87. elif rel.NEQ():
  88. return sympy.Ne(lh, rh)
  89. def convert_expr(expr):
  90. return convert_add(expr.additive())
  91. def convert_add(add):
  92. if add.ADD():
  93. lh = convert_add(add.additive(0))
  94. rh = convert_add(add.additive(1))
  95. return sympy.Add(lh, rh, evaluate=False)
  96. elif add.SUB():
  97. lh = convert_add(add.additive(0))
  98. rh = convert_add(add.additive(1))
  99. if hasattr(rh, "is_Atom") and rh.is_Atom:
  100. return sympy.Add(lh, -1 * rh, evaluate=False)
  101. return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False), evaluate=False)
  102. else:
  103. return convert_mp(add.mp())
  104. def convert_mp(mp):
  105. if hasattr(mp, 'mp'):
  106. mp_left = mp.mp(0)
  107. mp_right = mp.mp(1)
  108. else:
  109. mp_left = mp.mp_nofunc(0)
  110. mp_right = mp.mp_nofunc(1)
  111. if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
  112. lh = convert_mp(mp_left)
  113. rh = convert_mp(mp_right)
  114. return sympy.Mul(lh, rh, evaluate=False)
  115. elif mp.DIV() or mp.CMD_DIV() or mp.COLON():
  116. lh = convert_mp(mp_left)
  117. rh = convert_mp(mp_right)
  118. return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)
  119. else:
  120. if hasattr(mp, 'unary'):
  121. return convert_unary(mp.unary())
  122. else:
  123. return convert_unary(mp.unary_nofunc())
  124. def convert_unary(unary):
  125. if hasattr(unary, 'unary'):
  126. nested_unary = unary.unary()
  127. else:
  128. nested_unary = unary.unary_nofunc()
  129. if hasattr(unary, 'postfix_nofunc'):
  130. first = unary.postfix()
  131. tail = unary.postfix_nofunc()
  132. postfix = [first] + tail
  133. else:
  134. postfix = unary.postfix()
  135. if unary.ADD():
  136. return convert_unary(nested_unary)
  137. elif unary.SUB():
  138. numabs = convert_unary(nested_unary)
  139. # Use Integer(-n) instead of Mul(-1, n)
  140. return -numabs
  141. elif postfix:
  142. return convert_postfix_list(postfix)
  143. def convert_postfix_list(arr, i=0):
  144. if i >= len(arr):
  145. raise LaTeXParsingError("Index out of bounds")
  146. res = convert_postfix(arr[i])
  147. if isinstance(res, sympy.Expr):
  148. if i == len(arr) - 1:
  149. return res # nothing to multiply by
  150. else:
  151. if i > 0:
  152. left = convert_postfix(arr[i - 1])
  153. right = convert_postfix(arr[i + 1])
  154. if isinstance(left, sympy.Expr) and isinstance(
  155. right, sympy.Expr):
  156. left_syms = convert_postfix(arr[i - 1]).atoms(sympy.Symbol)
  157. right_syms = convert_postfix(arr[i + 1]).atoms(
  158. sympy.Symbol)
  159. # if the left and right sides contain no variables and the
  160. # symbol in between is 'x', treat as multiplication.
  161. if not (left_syms or right_syms) and str(res) == 'x':
  162. return convert_postfix_list(arr, i + 1)
  163. # multiply by next
  164. return sympy.Mul(
  165. res, convert_postfix_list(arr, i + 1), evaluate=False)
  166. else: # must be derivative
  167. wrt = res[0]
  168. if i == len(arr) - 1:
  169. raise LaTeXParsingError("Expected expression for derivative")
  170. else:
  171. expr = convert_postfix_list(arr, i + 1)
  172. return sympy.Derivative(expr, wrt)
  173. def do_subs(expr, at):
  174. if at.expr():
  175. at_expr = convert_expr(at.expr())
  176. syms = at_expr.atoms(sympy.Symbol)
  177. if len(syms) == 0:
  178. return expr
  179. elif len(syms) > 0:
  180. sym = next(iter(syms))
  181. return expr.subs(sym, at_expr)
  182. elif at.equality():
  183. lh = convert_expr(at.equality().expr(0))
  184. rh = convert_expr(at.equality().expr(1))
  185. return expr.subs(lh, rh)
  186. def convert_postfix(postfix):
  187. if hasattr(postfix, 'exp'):
  188. exp_nested = postfix.exp()
  189. else:
  190. exp_nested = postfix.exp_nofunc()
  191. exp = convert_exp(exp_nested)
  192. for op in postfix.postfix_op():
  193. if op.BANG():
  194. if isinstance(exp, list):
  195. raise LaTeXParsingError("Cannot apply postfix to derivative")
  196. exp = sympy.factorial(exp, evaluate=False)
  197. elif op.eval_at():
  198. ev = op.eval_at()
  199. at_b = None
  200. at_a = None
  201. if ev.eval_at_sup():
  202. at_b = do_subs(exp, ev.eval_at_sup())
  203. if ev.eval_at_sub():
  204. at_a = do_subs(exp, ev.eval_at_sub())
  205. if at_b is not None and at_a is not None:
  206. exp = sympy.Add(at_b, -1 * at_a, evaluate=False)
  207. elif at_b is not None:
  208. exp = at_b
  209. elif at_a is not None:
  210. exp = at_a
  211. return exp
  212. def convert_exp(exp):
  213. if hasattr(exp, 'exp'):
  214. exp_nested = exp.exp()
  215. else:
  216. exp_nested = exp.exp_nofunc()
  217. if exp_nested:
  218. base = convert_exp(exp_nested)
  219. if isinstance(base, list):
  220. raise LaTeXParsingError("Cannot raise derivative to power")
  221. if exp.atom():
  222. exponent = convert_atom(exp.atom())
  223. elif exp.expr():
  224. exponent = convert_expr(exp.expr())
  225. return sympy.Pow(base, exponent, evaluate=False)
  226. else:
  227. if hasattr(exp, 'comp'):
  228. return convert_comp(exp.comp())
  229. else:
  230. return convert_comp(exp.comp_nofunc())
  231. def convert_comp(comp):
  232. if comp.group():
  233. return convert_expr(comp.group().expr())
  234. elif comp.abs_group():
  235. return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False)
  236. elif comp.atom():
  237. return convert_atom(comp.atom())
  238. elif comp.floor():
  239. return convert_floor(comp.floor())
  240. elif comp.ceil():
  241. return convert_ceil(comp.ceil())
  242. elif comp.func():
  243. return convert_func(comp.func())
  244. def convert_atom(atom):
  245. if atom.LETTER():
  246. sname = atom.LETTER().getText()
  247. if atom.subexpr():
  248. if atom.subexpr().expr(): # subscript is expr
  249. subscript = convert_expr(atom.subexpr().expr())
  250. else: # subscript is atom
  251. subscript = convert_atom(atom.subexpr().atom())
  252. sname += '_{' + StrPrinter().doprint(subscript) + '}'
  253. if atom.SINGLE_QUOTES():
  254. sname += atom.SINGLE_QUOTES().getText() # put after subscript for easy identify
  255. return sympy.Symbol(sname)
  256. elif atom.SYMBOL():
  257. s = atom.SYMBOL().getText()[1:]
  258. if s == "infty":
  259. return sympy.oo
  260. else:
  261. if atom.subexpr():
  262. subscript = None
  263. if atom.subexpr().expr(): # subscript is expr
  264. subscript = convert_expr(atom.subexpr().expr())
  265. else: # subscript is atom
  266. subscript = convert_atom(atom.subexpr().atom())
  267. subscriptName = StrPrinter().doprint(subscript)
  268. s += '_{' + subscriptName + '}'
  269. return sympy.Symbol(s)
  270. elif atom.number():
  271. s = atom.number().getText().replace(",", "")
  272. return sympy.Number(s)
  273. elif atom.DIFFERENTIAL():
  274. var = get_differential_var(atom.DIFFERENTIAL())
  275. return sympy.Symbol('d' + var.name)
  276. elif atom.mathit():
  277. text = rule2text(atom.mathit().mathit_text())
  278. return sympy.Symbol(text)
  279. elif atom.frac():
  280. return convert_frac(atom.frac())
  281. elif atom.binom():
  282. return convert_binom(atom.binom())
  283. elif atom.bra():
  284. val = convert_expr(atom.bra().expr())
  285. return Bra(val)
  286. elif atom.ket():
  287. val = convert_expr(atom.ket().expr())
  288. return Ket(val)
  289. def rule2text(ctx):
  290. stream = ctx.start.getInputStream()
  291. # starting index of starting token
  292. startIdx = ctx.start.start
  293. # stopping index of stopping token
  294. stopIdx = ctx.stop.stop
  295. return stream.getText(startIdx, stopIdx)
  296. def convert_frac(frac):
  297. diff_op = False
  298. partial_op = False
  299. if frac.lower and frac.upper:
  300. lower_itv = frac.lower.getSourceInterval()
  301. lower_itv_len = lower_itv[1] - lower_itv[0] + 1
  302. if (frac.lower.start == frac.lower.stop
  303. and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
  304. wrt = get_differential_var_str(frac.lower.start.text)
  305. diff_op = True
  306. elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
  307. and frac.lower.start.text == '\\partial'
  308. and (frac.lower.stop.type == LaTeXLexer.LETTER
  309. or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
  310. partial_op = True
  311. wrt = frac.lower.stop.text
  312. if frac.lower.stop.type == LaTeXLexer.SYMBOL:
  313. wrt = wrt[1:]
  314. if diff_op or partial_op:
  315. wrt = sympy.Symbol(wrt)
  316. if (diff_op and frac.upper.start == frac.upper.stop
  317. and frac.upper.start.type == LaTeXLexer.LETTER
  318. and frac.upper.start.text == 'd'):
  319. return [wrt]
  320. elif (partial_op and frac.upper.start == frac.upper.stop
  321. and frac.upper.start.type == LaTeXLexer.SYMBOL
  322. and frac.upper.start.text == '\\partial'):
  323. return [wrt]
  324. upper_text = rule2text(frac.upper)
  325. expr_top = None
  326. if diff_op and upper_text.startswith('d'):
  327. expr_top = parse_latex(upper_text[1:])
  328. elif partial_op and frac.upper.start.text == '\\partial':
  329. expr_top = parse_latex(upper_text[len('\\partial'):])
  330. if expr_top:
  331. return sympy.Derivative(expr_top, wrt)
  332. if frac.upper:
  333. expr_top = convert_expr(frac.upper)
  334. else:
  335. expr_top = sympy.Number(frac.upperd.text)
  336. if frac.lower:
  337. expr_bot = convert_expr(frac.lower)
  338. else:
  339. expr_bot = sympy.Number(frac.lowerd.text)
  340. inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
  341. if expr_top == 1:
  342. return inverse_denom
  343. else:
  344. return sympy.Mul(expr_top, inverse_denom, evaluate=False)
  345. def convert_binom(binom):
  346. expr_n = convert_expr(binom.n)
  347. expr_k = convert_expr(binom.k)
  348. return sympy.binomial(expr_n, expr_k, evaluate=False)
  349. def convert_floor(floor):
  350. val = convert_expr(floor.val)
  351. return sympy.floor(val, evaluate=False)
  352. def convert_ceil(ceil):
  353. val = convert_expr(ceil.val)
  354. return sympy.ceiling(val, evaluate=False)
  355. def convert_func(func):
  356. if func.func_normal():
  357. if func.L_PAREN(): # function called with parenthesis
  358. arg = convert_func_arg(func.func_arg())
  359. else:
  360. arg = convert_func_arg(func.func_arg_noparens())
  361. name = func.func_normal().start.text[1:]
  362. # change arc<trig> -> a<trig>
  363. if name in [
  364. "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"
  365. ]:
  366. name = "a" + name[3:]
  367. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  368. if name in ["arsinh", "arcosh", "artanh"]:
  369. name = "a" + name[2:]
  370. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  371. if name == "exp":
  372. expr = sympy.exp(arg, evaluate=False)
  373. if name in ("log", "lg", "ln"):
  374. if func.subexpr():
  375. if func.subexpr().expr():
  376. base = convert_expr(func.subexpr().expr())
  377. else:
  378. base = convert_atom(func.subexpr().atom())
  379. elif name == "lg": # ISO 80000-2:2019
  380. base = 10
  381. elif name in ("ln", "log"): # SymPy's latex printer prints ln as log by default
  382. base = sympy.E
  383. expr = sympy.log(arg, base, evaluate=False)
  384. func_pow = None
  385. should_pow = True
  386. if func.supexpr():
  387. if func.supexpr().expr():
  388. func_pow = convert_expr(func.supexpr().expr())
  389. else:
  390. func_pow = convert_atom(func.supexpr().atom())
  391. if name in [
  392. "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh",
  393. "tanh"
  394. ]:
  395. if func_pow == -1:
  396. name = "a" + name
  397. should_pow = False
  398. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  399. if func_pow and should_pow:
  400. expr = sympy.Pow(expr, func_pow, evaluate=False)
  401. return expr
  402. elif func.LETTER() or func.SYMBOL():
  403. if func.LETTER():
  404. fname = func.LETTER().getText()
  405. elif func.SYMBOL():
  406. fname = func.SYMBOL().getText()[1:]
  407. fname = str(fname) # can't be unicode
  408. if func.subexpr():
  409. if func.subexpr().expr(): # subscript is expr
  410. subscript = convert_expr(func.subexpr().expr())
  411. else: # subscript is atom
  412. subscript = convert_atom(func.subexpr().atom())
  413. subscriptName = StrPrinter().doprint(subscript)
  414. fname += '_{' + subscriptName + '}'
  415. if func.SINGLE_QUOTES():
  416. fname += func.SINGLE_QUOTES().getText()
  417. input_args = func.args()
  418. output_args = []
  419. while input_args.args(): # handle multiple arguments to function
  420. output_args.append(convert_expr(input_args.expr()))
  421. input_args = input_args.args()
  422. output_args.append(convert_expr(input_args.expr()))
  423. return sympy.Function(fname)(*output_args)
  424. elif func.FUNC_INT():
  425. return handle_integral(func)
  426. elif func.FUNC_SQRT():
  427. expr = convert_expr(func.base)
  428. if func.root:
  429. r = convert_expr(func.root)
  430. return sympy.root(expr, r, evaluate=False)
  431. else:
  432. return sympy.sqrt(expr, evaluate=False)
  433. elif func.FUNC_OVERLINE():
  434. expr = convert_expr(func.base)
  435. return sympy.conjugate(expr, evaluate=False)
  436. elif func.FUNC_SUM():
  437. return handle_sum_or_prod(func, "summation")
  438. elif func.FUNC_PROD():
  439. return handle_sum_or_prod(func, "product")
  440. elif func.FUNC_LIM():
  441. return handle_limit(func)
  442. def convert_func_arg(arg):
  443. if hasattr(arg, 'expr'):
  444. return convert_expr(arg.expr())
  445. else:
  446. return convert_mp(arg.mp_nofunc())
  447. def handle_integral(func):
  448. if func.additive():
  449. integrand = convert_add(func.additive())
  450. elif func.frac():
  451. integrand = convert_frac(func.frac())
  452. else:
  453. integrand = 1
  454. int_var = None
  455. if func.DIFFERENTIAL():
  456. int_var = get_differential_var(func.DIFFERENTIAL())
  457. else:
  458. for sym in integrand.atoms(sympy.Symbol):
  459. s = str(sym)
  460. if len(s) > 1 and s[0] == 'd':
  461. if s[1] == '\\':
  462. int_var = sympy.Symbol(s[2:])
  463. else:
  464. int_var = sympy.Symbol(s[1:])
  465. int_sym = sym
  466. if int_var:
  467. integrand = integrand.subs(int_sym, 1)
  468. else:
  469. # Assume dx by default
  470. int_var = sympy.Symbol('x')
  471. if func.subexpr():
  472. if func.subexpr().atom():
  473. lower = convert_atom(func.subexpr().atom())
  474. else:
  475. lower = convert_expr(func.subexpr().expr())
  476. if func.supexpr().atom():
  477. upper = convert_atom(func.supexpr().atom())
  478. else:
  479. upper = convert_expr(func.supexpr().expr())
  480. return sympy.Integral(integrand, (int_var, lower, upper))
  481. else:
  482. return sympy.Integral(integrand, int_var)
  483. def handle_sum_or_prod(func, name):
  484. val = convert_mp(func.mp())
  485. iter_var = convert_expr(func.subeq().equality().expr(0))
  486. start = convert_expr(func.subeq().equality().expr(1))
  487. if func.supexpr().expr(): # ^{expr}
  488. end = convert_expr(func.supexpr().expr())
  489. else: # ^atom
  490. end = convert_atom(func.supexpr().atom())
  491. if name == "summation":
  492. return sympy.Sum(val, (iter_var, start, end))
  493. elif name == "product":
  494. return sympy.Product(val, (iter_var, start, end))
  495. def handle_limit(func):
  496. sub = func.limit_sub()
  497. if sub.LETTER():
  498. var = sympy.Symbol(sub.LETTER().getText())
  499. elif sub.SYMBOL():
  500. var = sympy.Symbol(sub.SYMBOL().getText()[1:])
  501. else:
  502. var = sympy.Symbol('x')
  503. if sub.SUB():
  504. direction = "-"
  505. elif sub.ADD():
  506. direction = "+"
  507. else:
  508. direction = "+-"
  509. approaching = convert_expr(sub.expr())
  510. content = convert_mp(func.mp())
  511. return sympy.Limit(content, var, approaching, direction)
  512. def get_differential_var(d):
  513. text = get_differential_var_str(d.getText())
  514. return sympy.Symbol(text)
  515. def get_differential_var_str(text):
  516. for i in range(1, len(text)):
  517. c = text[i]
  518. if not (c == " " or c == "\r" or c == "\n" or c == "\t"):
  519. idx = i
  520. break
  521. text = text[idx:]
  522. if text[0] == "\\":
  523. text = text[1:]
  524. return text