test_matmul.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import operator
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Index,
  7. Series,
  8. )
  9. import pandas._testing as tm
  class TestMatMul:
      """Tests for DataFrame matrix multiplication via ``operator.matmul`` / ``@``."""

      def test_matmul(self):
          """Check ``@`` with DataFrame, Series, ndarray and nested-list operands.

          Also covers a mixed-dtype and an all-integer left operand, and the
          ``ValueError`` raised when the operands' labels cannot be aligned.
          """
          # matmul test is for GH#10259
          a = DataFrame(
              np.random.default_rng(2).standard_normal((3, 4)),
              index=["a", "b", "c"],
              columns=["p", "q", "r", "s"],
          )
          # b's index matches a's columns, so a @ b is label-aligned
          b = DataFrame(
              np.random.default_rng(2).standard_normal((4, 2)),
              index=["p", "q", "r", "s"],
              columns=["one", "two"],
          )

          # DataFrame @ DataFrame
          result = operator.matmul(a, b)
          expected = DataFrame(
              np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
          )
          tm.assert_frame_equal(result, expected)

          # DataFrame @ Series
          result = operator.matmul(a, b.one)
          expected = Series(np.dot(a.values, b.one.values), index=["a", "b", "c"])
          tm.assert_series_equal(result, expected)

          # np.array @ DataFrame: the unlabeled left operand gives the result a
          # default integer row index, while the columns come from b
          result = operator.matmul(a.values, b)
          assert isinstance(result, DataFrame)
          assert result.columns.equals(b.columns)
          assert result.index.equals(Index(range(3)))
          expected = np.dot(a.values, b.values)
          tm.assert_almost_equal(result.values, expected)

          # nested list @ DataFrame (__rmatmul__): the list carries no labels
          # either, so verify the whole frame (default row index, b's columns)
          # instead of only the values
          result = operator.matmul(a.values.tolist(), b)
          expected = DataFrame(np.dot(a.values, b.values), columns=b.columns)
          tm.assert_frame_equal(result, expected)

          # mixed dtype DataFrame @ DataFrame (one int column among floats)
          a["q"] = a.q.round().astype(int)
          result = operator.matmul(a, b)
          expected = DataFrame(
              np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
          )
          tm.assert_frame_equal(result, expected)

          # different dtypes DataFrame @ DataFrame (int left, float right)
          a = a.astype(int)
          result = operator.matmul(a, b)
          expected = DataFrame(
              np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
          )
          tm.assert_frame_equal(result, expected)

          # unaligned: df's columns (0..3) differ from df2's index (0..4)
          df = DataFrame(
              np.random.default_rng(2).standard_normal((3, 4)),
              index=[1, 2, 3],
              columns=range(4),
          )
          df2 = DataFrame(
              np.random.default_rng(2).standard_normal((5, 3)),
              index=range(5),
              columns=[1, 2, 3],
          )
          with pytest.raises(ValueError, match="aligned"):
              operator.matmul(df, df2)

      def test_matmul_message_shapes(self):
          """ndarray/list @ DataFrame shape errors report the operands' shapes."""
          # GH#21581 exception message should reflect original shapes,
          # not transposed shapes
          a = np.random.default_rng(2).random((10, 4))
          b = np.random.default_rng(2).random((5, 3))
          df = DataFrame(b)
          msg = r"shapes \(10, 4\) and \(5, 3\) not aligned"
          with pytest.raises(ValueError, match=msg):
              a @ df
          with pytest.raises(ValueError, match=msg):
              a.tolist() @ df