tensorproduct.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """Abstract tensor product."""
  2. from sympy.core.add import Add
  3. from sympy.core.expr import Expr
  4. from sympy.core.kind import KindDispatcher
  5. from sympy.core.mul import Mul
  6. from sympy.core.power import Pow
  7. from sympy.core.sympify import sympify
  8. from sympy.matrices.dense import DenseMatrix as Matrix
  9. from sympy.matrices.immutable import ImmutableDenseMatrix as ImmutableMatrix
  10. from sympy.printing.pretty.stringpict import prettyForm
  11. from sympy.utilities.exceptions import sympy_deprecation_warning
  12. from sympy.physics.quantum.dagger import Dagger
  13. from sympy.physics.quantum.kind import (
  14. KetKind, _KetKind,
  15. BraKind, _BraKind,
  16. OperatorKind, _OperatorKind
  17. )
  18. from sympy.physics.quantum.matrixutils import (
  19. numpy_ndarray,
  20. scipy_sparse_matrix,
  21. matrix_tensor_product
  22. )
  23. from sympy.physics.quantum.state import Ket, Bra
  24. from sympy.physics.quantum.trace import Tr
  25. __all__ = [
  26. 'TensorProduct',
  27. 'tensor_product_simp'
  28. ]
  29. #-----------------------------------------------------------------------------
  30. # Tensor product
  31. #-----------------------------------------------------------------------------
  32. _combined_printing = False
  33. def combined_tensor_printing(combined):
  34. """Set flag controlling whether tensor products of states should be
  35. printed as a combined bra/ket or as an explicit tensor product of different
  36. bra/kets. This is a global setting for all TensorProduct class instances.
  37. Parameters
  38. ----------
  39. combine : bool
  40. When true, tensor product states are combined into one ket/bra, and
  41. when false explicit tensor product notation is used between each
  42. ket/bra.
  43. """
  44. global _combined_printing
  45. _combined_printing = combined
  46. class TensorProduct(Expr):
  47. """The tensor product of two or more arguments.
  48. For matrices, this uses ``matrix_tensor_product`` to compute the Kronecker
  49. or tensor product matrix. For other objects a symbolic ``TensorProduct``
  50. instance is returned. The tensor product is a non-commutative
  51. multiplication that is used primarily with operators and states in quantum
  52. mechanics.
  53. Currently, the tensor product distinguishes between commutative and
  54. non-commutative arguments. Commutative arguments are assumed to be scalars
  55. and are pulled out in front of the ``TensorProduct``. Non-commutative
  56. arguments remain in the resulting ``TensorProduct``.
  57. Parameters
  58. ==========
  59. args : tuple
  60. A sequence of the objects to take the tensor product of.
  61. Examples
  62. ========
  63. Start with a simple tensor product of SymPy matrices::
  64. >>> from sympy import Matrix
  65. >>> from sympy.physics.quantum import TensorProduct
  66. >>> m1 = Matrix([[1,2],[3,4]])
  67. >>> m2 = Matrix([[1,0],[0,1]])
  68. >>> TensorProduct(m1, m2)
  69. Matrix([
  70. [1, 0, 2, 0],
  71. [0, 1, 0, 2],
  72. [3, 0, 4, 0],
  73. [0, 3, 0, 4]])
  74. >>> TensorProduct(m2, m1)
  75. Matrix([
  76. [1, 2, 0, 0],
  77. [3, 4, 0, 0],
  78. [0, 0, 1, 2],
  79. [0, 0, 3, 4]])
  80. We can also construct tensor products of non-commutative symbols:
  81. >>> from sympy import Symbol
  82. >>> A = Symbol('A',commutative=False)
  83. >>> B = Symbol('B',commutative=False)
  84. >>> tp = TensorProduct(A, B)
  85. >>> tp
  86. AxB
  87. We can take the dagger of a tensor product (note the order does NOT reverse
  88. like the dagger of a normal product):
  89. >>> from sympy.physics.quantum import Dagger
  90. >>> Dagger(tp)
  91. Dagger(A)xDagger(B)
  92. Expand can be used to distribute a tensor product across addition:
  93. >>> C = Symbol('C',commutative=False)
  94. >>> tp = TensorProduct(A+B,C)
  95. >>> tp
  96. (A + B)xC
  97. >>> tp.expand(tensorproduct=True)
  98. AxC + BxC
  99. """
  100. is_commutative = False
  101. _kind_dispatcher = KindDispatcher("TensorProduct_kind_dispatcher", commutative=True)
  102. @property
  103. def kind(self):
  104. """Calculate the kind of a tensor product by looking at its children."""
  105. arg_kinds = (a.kind for a in self.args)
  106. return self._kind_dispatcher(*arg_kinds)
  107. def __new__(cls, *args):
  108. if isinstance(args[0], (Matrix, ImmutableMatrix, numpy_ndarray,
  109. scipy_sparse_matrix)):
  110. return matrix_tensor_product(*args)
  111. c_part, new_args = cls.flatten(sympify(args))
  112. c_part = Mul(*c_part)
  113. if len(new_args) == 0:
  114. return c_part
  115. elif len(new_args) == 1:
  116. return c_part * new_args[0]
  117. else:
  118. tp = Expr.__new__(cls, *new_args)
  119. return c_part * tp
  120. @classmethod
  121. def flatten(cls, args):
  122. # TODO: disallow nested TensorProducts.
  123. c_part = []
  124. nc_parts = []
  125. for arg in args:
  126. cp, ncp = arg.args_cnc()
  127. c_part.extend(list(cp))
  128. nc_parts.append(Mul._from_args(ncp))
  129. return c_part, nc_parts
  130. def _eval_adjoint(self):
  131. return TensorProduct(*[Dagger(i) for i in self.args])
  132. def _eval_rewrite(self, rule, args, **hints):
  133. return TensorProduct(*args).expand(tensorproduct=True)
  134. def _sympystr(self, printer, *args):
  135. length = len(self.args)
  136. s = ''
  137. for i in range(length):
  138. if isinstance(self.args[i], (Add, Pow, Mul)):
  139. s = s + '('
  140. s = s + printer._print(self.args[i])
  141. if isinstance(self.args[i], (Add, Pow, Mul)):
  142. s = s + ')'
  143. if i != length - 1:
  144. s = s + 'x'
  145. return s
  146. def _pretty(self, printer, *args):
  147. if (_combined_printing and
  148. (all(isinstance(arg, Ket) for arg in self.args) or
  149. all(isinstance(arg, Bra) for arg in self.args))):
  150. length = len(self.args)
  151. pform = printer._print('', *args)
  152. for i in range(length):
  153. next_pform = printer._print('', *args)
  154. length_i = len(self.args[i].args)
  155. for j in range(length_i):
  156. part_pform = printer._print(self.args[i].args[j], *args)
  157. next_pform = prettyForm(*next_pform.right(part_pform))
  158. if j != length_i - 1:
  159. next_pform = prettyForm(*next_pform.right(', '))
  160. if len(self.args[i].args) > 1:
  161. next_pform = prettyForm(
  162. *next_pform.parens(left='{', right='}'))
  163. pform = prettyForm(*pform.right(next_pform))
  164. if i != length - 1:
  165. pform = prettyForm(*pform.right(',' + ' '))
  166. pform = prettyForm(*pform.left(self.args[0].lbracket))
  167. pform = prettyForm(*pform.right(self.args[0].rbracket))
  168. return pform
  169. length = len(self.args)
  170. pform = printer._print('', *args)
  171. for i in range(length):
  172. next_pform = printer._print(self.args[i], *args)
  173. if isinstance(self.args[i], (Add, Mul)):
  174. next_pform = prettyForm(
  175. *next_pform.parens(left='(', right=')')
  176. )
  177. pform = prettyForm(*pform.right(next_pform))
  178. if i != length - 1:
  179. if printer._use_unicode:
  180. pform = prettyForm(*pform.right('\N{N-ARY CIRCLED TIMES OPERATOR}' + ' '))
  181. else:
  182. pform = prettyForm(*pform.right('x' + ' '))
  183. return pform
  184. def _latex(self, printer, *args):
  185. if (_combined_printing and
  186. (all(isinstance(arg, Ket) for arg in self.args) or
  187. all(isinstance(arg, Bra) for arg in self.args))):
  188. def _label_wrap(label, nlabels):
  189. return label if nlabels == 1 else r"\left\{%s\right\}" % label
  190. s = r", ".join([_label_wrap(arg._print_label_latex(printer, *args),
  191. len(arg.args)) for arg in self.args])
  192. return r"{%s%s%s}" % (self.args[0].lbracket_latex, s,
  193. self.args[0].rbracket_latex)
  194. length = len(self.args)
  195. s = ''
  196. for i in range(length):
  197. if isinstance(self.args[i], (Add, Mul)):
  198. s = s + '\\left('
  199. # The extra {} brackets are needed to get matplotlib's latex
  200. # rendered to render this properly.
  201. s = s + '{' + printer._print(self.args[i], *args) + '}'
  202. if isinstance(self.args[i], (Add, Mul)):
  203. s = s + '\\right)'
  204. if i != length - 1:
  205. s = s + '\\otimes '
  206. return s
  207. def doit(self, **hints):
  208. return TensorProduct(*[item.doit(**hints) for item in self.args])
  209. def _eval_expand_tensorproduct(self, **hints):
  210. """Distribute TensorProducts across addition."""
  211. args = self.args
  212. add_args = []
  213. for i in range(len(args)):
  214. if isinstance(args[i], Add):
  215. for aa in args[i].args:
  216. tp = TensorProduct(*args[:i] + (aa,) + args[i + 1:])
  217. c_part, nc_part = tp.args_cnc()
  218. # Check for TensorProduct object: is the one object in nc_part, if any:
  219. # (Note: any other object type to be expanded must be added here)
  220. if len(nc_part) == 1 and isinstance(nc_part[0], TensorProduct):
  221. nc_part = (nc_part[0]._eval_expand_tensorproduct(), )
  222. add_args.append(Mul(*c_part)*Mul(*nc_part))
  223. break
  224. if add_args:
  225. return Add(*add_args)
  226. else:
  227. return self
  228. def _eval_trace(self, **kwargs):
  229. indices = kwargs.get('indices', None)
  230. exp = self
  231. if indices is None or len(indices) == 0:
  232. return Mul(*[Tr(arg).doit() for arg in exp.args])
  233. else:
  234. return Mul(*[Tr(value).doit() if idx in indices else value
  235. for idx, value in enumerate(exp.args)])
  236. def tensor_product_simp_Mul(e):
  237. """Simplify a Mul with tensor products.
  238. .. deprecated:: 1.14.
  239. The transformations applied by this function are not done automatically
  240. when tensor products are combined.
  241. Originally, the main use of this function is to simplify a ``Mul`` of
  242. ``TensorProduct``s to a ``TensorProduct`` of ``Muls``.
  243. """
  244. sympy_deprecation_warning(
  245. """
  246. tensor_product_simp_Mul has been deprecated. The transformations
  247. performed by this function are now done automatically when
  248. tensor products are multiplied.
  249. """,
  250. deprecated_since_version="1.14",
  251. active_deprecations_target='deprecated-tensorproduct-simp'
  252. )
  253. return e
  254. def tensor_product_simp_Pow(e):
  255. """Evaluates ``Pow`` expressions whose base is ``TensorProduct``
  256. .. deprecated:: 1.14.
  257. The transformations applied by this function are not done automatically
  258. when tensor products are combined.
  259. """
  260. sympy_deprecation_warning(
  261. """
  262. tensor_product_simp_Pow has been deprecated. The transformations
  263. performed by this function are now done automatically when
  264. tensor products are exponentiated.
  265. """,
  266. deprecated_since_version="1.14",
  267. active_deprecations_target='deprecated-tensorproduct-simp'
  268. )
  269. return e
  270. def tensor_product_simp(e, **hints):
  271. """Try to simplify and combine tensor products.
  272. .. deprecated:: 1.14.
  273. The transformations applied by this function are not done automatically
  274. when tensor products are combined.
  275. Originally, this function tried to pull expressions inside of ``TensorProducts``.
  276. It only worked for relatively simple cases where the products have
  277. only scalars, raw ``TensorProducts``, not ``Add``, ``Pow``, ``Commutators``
  278. of ``TensorProducts``.
  279. """
  280. sympy_deprecation_warning(
  281. """
  282. tensor_product_simp has been deprecated. The transformations
  283. performed by this function are now done automatically when
  284. tensor products are combined.
  285. """,
  286. deprecated_since_version="1.14",
  287. active_deprecations_target='deprecated-tensorproduct-simp'
  288. )
  289. return e
  290. @TensorProduct._kind_dispatcher.register(_OperatorKind, _OperatorKind)
  291. def find_op_kind(e1, e2):
  292. return OperatorKind
  293. @TensorProduct._kind_dispatcher.register(_KetKind, _KetKind)
  294. def find_ket_kind(e1, e2):
  295. return KetKind
  296. @TensorProduct._kind_dispatcher.register(_BraKind, _BraKind)
  297. def find_bra_kind(e1, e2):
  298. return BraKind