# test_operator.py (8.2 KB; 269 lines in the original listing)
  1. from sympy.core.function import (Derivative, Function, diff)
  2. from sympy.core.mul import Mul
  3. from sympy.core.numbers import (Integer, pi)
  4. from sympy.core.symbol import (Symbol, symbols)
  5. from sympy.core.sympify import sympify
  6. from sympy.functions.elementary.trigonometric import sin
  7. from sympy.physics.quantum.qexpr import QExpr
  8. from sympy.physics.quantum.dagger import Dagger
  9. from sympy.physics.quantum.hilbert import HilbertSpace
  10. from sympy.physics.quantum.operator import (Operator, UnitaryOperator,
  11. HermitianOperator, OuterProduct,
  12. DifferentialOperator,
  13. IdentityOperator)
  14. from sympy.physics.quantum.state import Ket, Bra, Wavefunction
  15. from sympy.physics.quantum.qapply import qapply
  16. from sympy.physics.quantum.represent import represent
  17. from sympy.physics.quantum.spin import JzKet, JzBra
  18. from sympy.physics.quantum.trace import Tr
  19. from sympy.matrices import eye
  20. from sympy.testing.pytest import warns_deprecated_sympy
  21. class CustomKet(Ket):
  22. @classmethod
  23. def default_args(self):
  24. return ("t",)
  25. class CustomOp(HermitianOperator):
  26. @classmethod
  27. def default_args(self):
  28. return ("T",)
  29. t_ket = CustomKet()
  30. t_op = CustomOp()
  31. def test_operator():
  32. A = Operator('A')
  33. B = Operator('B')
  34. C = Operator('C')
  35. assert isinstance(A, Operator)
  36. assert isinstance(A, QExpr)
  37. assert A.label == (Symbol('A'),)
  38. assert A.is_commutative is False
  39. assert A.hilbert_space == HilbertSpace()
  40. assert A*B != B*A
  41. assert (A*(B + C)).expand() == A*B + A*C
  42. assert ((A + B)**2).expand() == A**2 + A*B + B*A + B**2
  43. assert t_op.label[0] == Symbol(t_op.default_args()[0])
  44. assert Operator() == Operator("O")
  45. with warns_deprecated_sympy():
  46. assert A*IdentityOperator() == A
  47. def test_operator_inv():
  48. A = Operator('A')
  49. assert A*A.inv() == 1
  50. assert A.inv()*A == 1
  51. def test_hermitian():
  52. H = HermitianOperator('H')
  53. assert isinstance(H, HermitianOperator)
  54. assert isinstance(H, Operator)
  55. assert Dagger(H) == H
  56. assert H.inv() != H
  57. assert H.is_commutative is False
  58. assert Dagger(H).is_commutative is False
  59. def test_unitary():
  60. U = UnitaryOperator('U')
  61. assert isinstance(U, UnitaryOperator)
  62. assert isinstance(U, Operator)
  63. assert U.inv() == Dagger(U)
  64. assert U*Dagger(U) == 1
  65. assert Dagger(U)*U == 1
  66. assert U.is_commutative is False
  67. assert Dagger(U).is_commutative is False
  68. def test_identity():
  69. with warns_deprecated_sympy():
  70. I = IdentityOperator()
  71. O = Operator('O')
  72. x = Symbol("x")
  73. three = sympify(3)
  74. assert isinstance(I, IdentityOperator)
  75. assert isinstance(I, Operator)
  76. assert I * O == O
  77. assert O * I == O
  78. assert I * Dagger(O) == Dagger(O)
  79. assert Dagger(O) * I == Dagger(O)
  80. assert isinstance(I * I, IdentityOperator)
  81. assert three * I == three
  82. assert I * x == x
  83. assert I.inv() == I
  84. assert Dagger(I) == I
  85. assert qapply(I * O) == O
  86. assert qapply(O * I) == O
  87. for n in [2, 3, 5]:
  88. assert represent(IdentityOperator(n)) == eye(n)
  89. def test_outer_product():
  90. k = Ket('k')
  91. b = Bra('b')
  92. op = OuterProduct(k, b)
  93. assert isinstance(op, OuterProduct)
  94. assert isinstance(op, Operator)
  95. assert op.ket == k
  96. assert op.bra == b
  97. assert op.label == (k, b)
  98. assert op.is_commutative is False
  99. op = k*b
  100. assert isinstance(op, OuterProduct)
  101. assert isinstance(op, Operator)
  102. assert op.ket == k
  103. assert op.bra == b
  104. assert op.label == (k, b)
  105. assert op.is_commutative is False
  106. op = 2*k*b
  107. assert op == Mul(Integer(2), k, b)
  108. op = 2*(k*b)
  109. assert op == Mul(Integer(2), OuterProduct(k, b))
  110. assert Dagger(k*b) == OuterProduct(Dagger(b), Dagger(k))
  111. assert Dagger(k*b).is_commutative is False
  112. #test the _eval_trace
  113. assert Tr(OuterProduct(JzKet(1, 1), JzBra(1, 1))).doit() == 1
  114. # test scaled kets and bras
  115. assert OuterProduct(2 * k, b) == 2 * OuterProduct(k, b)
  116. assert OuterProduct(k, 2 * b) == 2 * OuterProduct(k, b)
  117. # test sums of kets and bras
  118. k1, k2 = Ket('k1'), Ket('k2')
  119. b1, b2 = Bra('b1'), Bra('b2')
  120. assert (OuterProduct(k1 + k2, b1) ==
  121. OuterProduct(k1, b1) + OuterProduct(k2, b1))
  122. assert (OuterProduct(k1, b1 + b2) ==
  123. OuterProduct(k1, b1) + OuterProduct(k1, b2))
  124. assert (OuterProduct(1 * k1 + 2 * k2, 3 * b1 + 4 * b2) ==
  125. 3 * OuterProduct(k1, b1) +
  126. 4 * OuterProduct(k1, b2) +
  127. 6 * OuterProduct(k2, b1) +
  128. 8 * OuterProduct(k2, b2))
  129. def test_operator_dagger():
  130. A = Operator('A')
  131. B = Operator('B')
  132. assert Dagger(A*B) == Dagger(B)*Dagger(A)
  133. assert Dagger(A + B) == Dagger(A) + Dagger(B)
  134. assert Dagger(A**2) == Dagger(A)**2
  135. def test_differential_operator():
  136. x = Symbol('x')
  137. f = Function('f')
  138. d = DifferentialOperator(Derivative(f(x), x), f(x))
  139. g = Wavefunction(x**2, x)
  140. assert qapply(d*g) == Wavefunction(2*x, x)
  141. assert d.expr == Derivative(f(x), x)
  142. assert d.function == f(x)
  143. assert d.variables == (x,)
  144. assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 2), f(x))
  145. d = DifferentialOperator(Derivative(f(x), x, 2), f(x))
  146. g = Wavefunction(x**3, x)
  147. assert qapply(d*g) == Wavefunction(6*x, x)
  148. assert d.expr == Derivative(f(x), x, 2)
  149. assert d.function == f(x)
  150. assert d.variables == (x,)
  151. assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 3), f(x))
  152. d = DifferentialOperator(1/x*Derivative(f(x), x), f(x))
  153. assert d.expr == 1/x*Derivative(f(x), x)
  154. assert d.function == f(x)
  155. assert d.variables == (x,)
  156. assert diff(d, x) == \
  157. DifferentialOperator(Derivative(1/x*Derivative(f(x), x), x), f(x))
  158. assert qapply(d*g) == Wavefunction(3*x, x)
  159. # 2D cartesian Laplacian
  160. y = Symbol('y')
  161. d = DifferentialOperator(Derivative(f(x, y), x, 2) +
  162. Derivative(f(x, y), y, 2), f(x, y))
  163. w = Wavefunction(x**3*y**2 + y**3*x**2, x, y)
  164. assert d.expr == Derivative(f(x, y), x, 2) + Derivative(f(x, y), y, 2)
  165. assert d.function == f(x, y)
  166. assert d.variables == (x, y)
  167. assert diff(d, x) == \
  168. DifferentialOperator(Derivative(d.expr, x), f(x, y))
  169. assert diff(d, y) == \
  170. DifferentialOperator(Derivative(d.expr, y), f(x, y))
  171. assert qapply(d*w) == Wavefunction(2*x**3 + 6*x*y**2 + 6*x**2*y + 2*y**3,
  172. x, y)
  173. # 2D polar Laplacian (th = theta)
  174. r, th = symbols('r th')
  175. d = DifferentialOperator(1/r*Derivative(r*Derivative(f(r, th), r), r) +
  176. 1/(r**2)*Derivative(f(r, th), th, 2), f(r, th))
  177. w = Wavefunction(r**2*sin(th), r, (th, 0, pi))
  178. assert d.expr == \
  179. 1/r*Derivative(r*Derivative(f(r, th), r), r) + \
  180. 1/(r**2)*Derivative(f(r, th), th, 2)
  181. assert d.function == f(r, th)
  182. assert d.variables == (r, th)
  183. assert diff(d, r) == \
  184. DifferentialOperator(Derivative(d.expr, r), f(r, th))
  185. assert diff(d, th) == \
  186. DifferentialOperator(Derivative(d.expr, th), f(r, th))
  187. assert qapply(d*w) == Wavefunction(3*sin(th), r, (th, 0, pi))
  188. def test_eval_power():
  189. from sympy.core import Pow
  190. from sympy.core.expr import unchanged
  191. O = Operator('O')
  192. U = UnitaryOperator('U')
  193. H = HermitianOperator('H')
  194. assert O**-1 == O.inv() # same as doc test
  195. assert U**-1 == U.inv()
  196. assert H**-1 == H.inv()
  197. x = symbols("x", commutative = True)
  198. assert unchanged(Pow, H, x) # verify Pow(H,x)=="X^n"
  199. assert H**x == Pow(H, x)
  200. assert Pow(H,x) == Pow(H, x, evaluate=False) # Just check
  201. from sympy.physics.quantum.gate import XGate
  202. X = XGate(0) # is hermitian and unitary
  203. assert unchanged(Pow, X, x) # verify Pow(X,x)=="X^x"
  204. assert X**x == Pow(X, x)
  205. assert Pow(X, x, evaluate=False) == Pow(X, x) # Just check
  206. n = symbols("n", integer=True, even=True)
  207. assert X**n == 1
  208. n = symbols("n", integer=True, odd=True)
  209. assert X**n == X
  210. n = symbols("n", integer=True)
  211. assert unchanged(Pow, X, n) # verify Pow(X,n)=="X^n"
  212. assert X**n == Pow(X, n)
  213. assert Pow(X, n, evaluate=False)==Pow(X, n) # Just check
  214. assert X**4 == 1
  215. assert X**7 == X