test_compilation.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import shutil
  2. import os
  3. import subprocess
  4. import tempfile
  5. from sympy.external import import_module
  6. from sympy.testing.pytest import skip, skip_under_pyodide
  7. from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath
# Optional third-party dependencies, loaded through sympy's ``import_module``.
# The tests below treat a falsy value here as "not installed" and skip.
numpy = import_module('numpy')
cython = import_module('cython')
  10. _sources1 = [
  11. ('sigmoid.c', r"""
  12. #include <math.h>
  13. void sigmoid(int n, const double * const restrict in,
  14. double * const restrict out, double lim){
  15. for (int i=0; i<n; ++i){
  16. const double x = in[i];
  17. out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
  18. }
  19. }
  20. """),
  21. ('_sigmoid.pyx', r"""
  22. import numpy as np
  23. cimport numpy as cnp
  24. cdef extern void c_sigmoid "sigmoid" (int, const double * const,
  25. double * const, double)
  26. def sigmoid(double [:] inp, double lim=350.0):
  27. cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
  28. inp.size, dtype=np.float64)
  29. c_sigmoid(inp.size, &inp[0], &out[0], lim)
  30. return out
  31. """)
  32. ]
  33. def npy(data, lim=350.0):
  34. return data/((data/lim)**8+1)**(1/8.)
  35. def test_compile_link_import_strings():
  36. if not numpy:
  37. skip("numpy not installed.")
  38. if not cython:
  39. skip("cython not installed.")
  40. from sympy.utilities._compilation import has_c
  41. if not has_c():
  42. skip("No C compiler found.")
  43. compile_kw = {"std": 'c99', "include_dirs": [numpy.get_include()]}
  44. info = None
  45. try:
  46. mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
  47. data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
  48. res_mod = mod.sigmoid(data)
  49. res_npy = npy(data)
  50. assert numpy.allclose(res_mod, res_npy)
  51. finally:
  52. if info and info['build_dir']:
  53. shutil.rmtree(info['build_dir'])
  54. @skip_under_pyodide("Emscripten does not support subprocesses")
  55. def test_compile_sources():
  56. tmpdir = tempfile.mkdtemp()
  57. from sympy.utilities._compilation import has_c
  58. if not has_c():
  59. skip("No C compiler found.")
  60. build_dir = str(tmpdir)
  61. _handle, file_path = tempfile.mkstemp('.c', dir=build_dir)
  62. with open(file_path, 'wt') as ofh:
  63. ofh.write("""
  64. int foo(int bar) {
  65. return 2*bar;
  66. }
  67. """)
  68. obj, = compile_sources([file_path], cwd=build_dir)
  69. obj_path = get_abspath(obj, cwd=build_dir)
  70. assert os.path.exists(obj_path)
  71. try:
  72. _ = subprocess.check_output(["nm", "--help"])
  73. except subprocess.CalledProcessError:
  74. pass # we cannot test contents of object file
  75. else:
  76. nm_out = subprocess.check_output(["nm", obj_path])
  77. assert 'foo' in nm_out.decode('utf-8')
  78. if not cython:
  79. return # the final (optional) part of the test below requires Cython.
  80. _handle, pyx_path = tempfile.mkstemp('.pyx', dir=build_dir)
  81. with open(pyx_path, 'wt') as ofh:
  82. ofh.write(("cdef extern int foo(int)\n"
  83. "def _foo(arg):\n"
  84. " return foo(arg)"))
  85. mod = compile_link_import_py_ext([pyx_path], extra_objs=[obj_path], build_dir=build_dir)
  86. assert mod._foo(21) == 42