test_matmul.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import operator
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Series,
  7. )
  8. import pandas._testing as tm
  class TestMatmul:
      """Tests for ``@`` (``operator.matmul``) between Series, DataFrame and arrays."""

      def test_matmul(self):
          """Check Series/DataFrame matmul, ``__rmatmul__`` and ``Series.dot`` errors.

          Every product is compared against ``np.dot`` applied to the
          underlying ``.values``. The two cases at the end pin the messages
          raised by ``Series.dot`` for a length mismatch against a plain array
          and for operands whose indexes are not aligned.
          """
          # matmul test is for GH#10259
          a = Series(
              np.random.default_rng(2).standard_normal(4), index=["p", "q", "r", "s"]
          )
          b = DataFrame(
              np.random.default_rng(2).standard_normal((3, 4)),
              index=["1", "2", "3"],
              columns=["p", "q", "r", "s"],
          ).T
          # Index of every Series-valued product below: the columns of ``b``
          # (equivalently, the rows of ``b.T``).
          result_index = ["1", "2", "3"]

          # Series @ DataFrame -> Series
          result = operator.matmul(a, b)
          expected = Series(np.dot(a.values, b.values), index=result_index)
          tm.assert_series_equal(result, expected)

          # DataFrame @ Series -> Series
          # (``Series.T`` returns the Series itself, so ``a.values`` is used
          # directly rather than ``a.T.values``.)
          result = operator.matmul(b.T, a)
          expected = Series(np.dot(b.T.values, a.values), index=result_index)
          tm.assert_series_equal(result, expected)

          # Series @ Series -> scalar
          result = operator.matmul(a, a)
          expected = np.dot(a.values, a.values)
          tm.assert_almost_equal(result, expected)

          # GH#21530
          # vector (1D np.array) @ Series (__rmatmul__)
          result = operator.matmul(a.values, a)
          expected = np.dot(a.values, a.values)
          tm.assert_almost_equal(result, expected)

          # GH#21530
          # vector (1D list) @ Series (__rmatmul__)
          result = operator.matmul(a.values.tolist(), a)
          expected = np.dot(a.values, a.values)
          tm.assert_almost_equal(result, expected)

          # GH#21530
          # matrix (2D np.array) @ Series (__rmatmul__)
          result = operator.matmul(b.T.values, a)
          expected = np.dot(b.T.values, a.values)
          tm.assert_almost_equal(result, expected)

          # GH#21530
          # matrix (2D nested lists) @ Series (__rmatmul__)
          result = operator.matmul(b.T.values.tolist(), a)
          expected = np.dot(b.T.values, a.values)
          tm.assert_almost_equal(result, expected)

          # DataFrame @ Series after storing an integer-valued element into the
          # existing float Series ``a``. ``b.T`` itself is unchanged and
          # all-float here, so this is NOT a mixed-dtype DataFrame case.
          a["p"] = int(a.p)
          result = operator.matmul(b.T, a)
          expected = Series(np.dot(b.T.values, a.values), index=result_index)
          tm.assert_series_equal(result, expected)

          # mixed dtype DataFrame @ Series: column "p" is int64, the rest float.
          mixed = b.T.astype({"p": np.int64})
          result = operator.matmul(mixed, a)
          expected = Series(np.dot(mixed.values, a.values), index=result_index)
          tm.assert_series_equal(result, expected)

          # different dtypes DataFrame @ Series: float DataFrame, int Series
          a = a.astype(int)
          result = operator.matmul(b.T, a)
          expected = Series(np.dot(b.T.values, a.values), index=result_index)
          tm.assert_series_equal(result, expected)

          # Series.dot with a shorter plain array: the exception raised is of
          # type ``Exception`` itself, hence the broad ``pytest.raises``.
          msg = r"Dot product shape mismatch, \(4,\) vs \(3,\)"
          with pytest.raises(Exception, match=msg):
              a.dot(a.values[:3])

          # Series.dot with a DataFrame whose index differs from ``a``'s index.
          msg = "matrices are not aligned"
          with pytest.raises(ValueError, match=msg):
              a.dot(b.T)