utils.py 36 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099
  1. import functools
  2. import inspect
  3. import sys
  4. import warnings
  5. from contextlib import contextmanager
  6. import numpy as np
  7. from ._warnings import all_warnings, warn
  8. __all__ = [
  9. 'deprecate_func',
  10. 'get_bound_method_class',
  11. 'all_warnings',
  12. 'safe_as_int',
  13. 'check_shape_equality',
  14. 'check_nD',
  15. 'warn',
  16. 'reshape_nd',
  17. 'identity',
  18. 'slice_at_axis',
  19. "deprecate_parameter",
  20. "DEPRECATED",
  21. ]
  22. def count_inner_wrappers(func):
  23. """Count the number of inner wrappers by unpacking ``__wrapped__``.
  24. If a wrapped function wraps another wrapped function, then we refer to the
  25. wrapping of the second function as an *inner wrapper*.
  26. For example, consider this code fragment:
  27. .. code-block:: python
  28. @wrap_outer
  29. @wrap_inner
  30. def foo():
  31. pass
  32. Here ``@wrap_inner`` applies a wrapper to ``foo``, and ``@wrap_outer``
  33. applies a wrapper to the result.
  34. Parameters
  35. ----------
  36. func : callable
  37. The callable of which to determine the number of inner wrappers.
  38. Returns
  39. -------
  40. count : int
  41. The number of times `func` has been wrapped.
  42. See Also
  43. --------
  44. count_global_wrappers
  45. """
  46. unwrapped = func
  47. count = 0
  48. while hasattr(unwrapped, "__wrapped__"):
  49. unwrapped = unwrapped.__wrapped__
  50. count += 1
  51. return count
  52. def _warning_stacklevel(func):
  53. """Find stacklevel of `func` relative to its global representation.
  54. Determine automatically with which stacklevel a warning should be raised.
  55. Parameters
  56. ----------
  57. func : Callable
  58. Tries to find the global version of `func` and counts the number of
  59. additional wrappers around `func`.
  60. Returns
  61. -------
  62. stacklevel : int
  63. The stacklevel. Minimum of 2.
  64. """
  65. # Count number of wrappers around `func`
  66. inner_wrapped_count = count_inner_wrappers(func)
  67. global_wrapped_count = count_global_wrappers(func)
  68. stacklevel = global_wrapped_count - inner_wrapped_count + 1
  69. return max(stacklevel, 2)
  70. def count_global_wrappers(func):
  71. """Count the total number of times a function as been wrapped globally.
  72. Similar to :func:`count_inner_wrappers`, this counts the number of times
  73. `func` has been wrapped. However, this function doesn't start counting
  74. from `func` but instead tries to access the "global representation" of
  75. `func`. This means that you could use this function from inside a wrapper
  76. that was applied first, and still count wrappers that were applied on
  77. top of it afterwards.
  78. E.g., `func` might be wrapped by multiple decorators that emit
  79. warnings. In that case, calling this function in the inner-most decorator
  80. will still return the total count of wrappers.
  81. Parameters
  82. ----------
  83. func : callable
  84. The callable of which to determine the number of wrappers. Can be a
  85. function or method of a class.
  86. Returns
  87. -------
  88. count : int
  89. The number of times `func` has been wrapped.
  90. See Also
  91. --------
  92. count_inner_wrappers
  93. """
  94. if "<locals>" in func.__qualname__:
  95. msg = (
  96. "Cannot determine stacklevel of a function defined in another "
  97. "function's local namespace. Set the stacklevel manually."
  98. )
  99. raise ValueError(msg)
  100. first_name, *other = func.__qualname__.split(".")
  101. global_func = func.__globals__.get(first_name, func)
  102. # Account for `func` being a method, in which case it's an attribute of
  103. # what we got from `func.__globals__`
  104. for part in other:
  105. global_func = getattr(global_func, part, global_func)
  106. count = count_inner_wrappers(global_func)
  107. assert count >= 0
  108. return count
  109. class change_default_value:
  110. """Decorator for changing the default value of an argument.
  111. Parameters
  112. ----------
  113. arg_name : str
  114. The name of the argument to be updated.
  115. new_value : any
  116. The argument new value.
  117. changed_version : str
  118. The package version in which the change will be introduced.
  119. warning_msg : str
  120. Optional warning message. If None, a generic warning message
  121. is used.
  122. stacklevel : {None, int}, optional
  123. If None, the decorator attempts to detect the appropriate stacklevel for the
  124. deprecation warning automatically. This can fail, e.g., due to
  125. decorating a closure, in which case you can set the stacklevel manually
  126. here. The outermost decorator should have stacklevel 2, the next inner
  127. one stacklevel 3, etc.
  128. """
  129. def __init__(
  130. self, arg_name, *, new_value, changed_version, warning_msg=None, stacklevel=None
  131. ):
  132. self.arg_name = arg_name
  133. self.new_value = new_value
  134. self.warning_msg = warning_msg
  135. self.changed_version = changed_version
  136. self.stacklevel = stacklevel
  137. def __call__(self, func):
  138. parameters = inspect.signature(func).parameters
  139. arg_idx = list(parameters.keys()).index(self.arg_name)
  140. old_value = parameters[self.arg_name].default
  141. if self.warning_msg is None:
  142. self.warning_msg = (
  143. f'The new recommended value for {self.arg_name} is '
  144. f'{self.new_value}. Until version {self.changed_version}, '
  145. f'the default {self.arg_name} value is {old_value}. '
  146. f'From version {self.changed_version}, the {self.arg_name} '
  147. f'default value will be {self.new_value}. To avoid '
  148. f'this warning, please explicitly set {self.arg_name} value.'
  149. )
  150. @functools.wraps(func)
  151. def fixed_func(*args, **kwargs):
  152. if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
  153. stacklevel = (
  154. self.stacklevel
  155. if self.stacklevel is not None
  156. else _warning_stacklevel(func)
  157. )
  158. # warn that arg_name default value changed:
  159. warnings.warn(self.warning_msg, FutureWarning, stacklevel=stacklevel)
  160. return func(*args, **kwargs)
  161. return fixed_func
  162. class PatchClassRepr(type):
  163. """Control class representations in rendered signatures."""
  164. def __repr__(cls):
  165. return f"<{cls.__name__}>"
  166. class DEPRECATED(metaclass=PatchClassRepr):
  167. """Signal value to help with deprecating parameters that use None.
  168. This is a proxy object, used to signal that a parameter has not been set.
  169. This is useful if ``None`` is already used for a different purpose or just
  170. to highlight a deprecated parameter in the signature.
  171. """
  172. class deprecate_parameter:
  173. """Deprecate a parameter of a function.
  174. Parameters
  175. ----------
  176. deprecated_name : str
  177. The name of the deprecated parameter.
  178. start_version : str
  179. The package version in which the warning was introduced.
  180. stop_version : str
  181. The package version in which the warning will be replaced by
  182. an error / the deprecation is completed.
  183. template : str, optional
  184. If given, this message template is used instead of the default one.
  185. new_name : str, optional
  186. If given, the default message will recommend the new parameter name and an
  187. error will be raised if the user uses both old and new names for the
  188. same parameter.
  189. modify_docstring : bool, optional
  190. If the wrapped function has a docstring, add the deprecated parameters
  191. to the "Other Parameters" section.
  192. stacklevel : {None, int}, optional
  193. If None, the decorator attempts to detect the appropriate stacklevel for the
  194. deprecation warning automatically. This can fail, e.g., due to
  195. decorating a closure, in which case you can set the stacklevel manually
  196. here. The outermost decorator should have stacklevel 2, the next inner
  197. one stacklevel 3, etc.
  198. Notes
  199. -----
  200. Assign `DEPRECATED` as the new default value for the deprecated parameter.
  201. This marks the status of the parameter also in the signature and rendered
  202. HTML docs.
  203. This decorator can be stacked to deprecate more than one parameter.
  204. Examples
  205. --------
  206. >>> from skimage._shared.utils import deprecate_parameter, DEPRECATED
  207. >>> @deprecate_parameter(
  208. ... "b", new_name="c", start_version="0.1", stop_version="0.3"
  209. ... )
  210. ... def foo(a, b=DEPRECATED, *, c=None):
  211. ... return a, c
  212. Calling ``foo(1, b=2)`` will warn with::
  213. FutureWarning: Parameter `b` is deprecated since version 0.1 and will
  214. be removed in 0.3 (or later). To avoid this warning, please use the
  215. parameter `c` instead. For more details, see the documentation of
  216. `foo`.
  217. """
  218. DEPRECATED = DEPRECATED # Make signal value accessible for convenience
  219. remove_parameter_template = (
  220. "Parameter `{deprecated_name}` is deprecated since version "
  221. "{deprecated_version} and will be removed in {changed_version} (or "
  222. "later). To avoid this warning, please do not use the parameter "
  223. "`{deprecated_name}`. For more details, see the documentation of "
  224. "`{func_name}`."
  225. )
  226. replace_parameter_template = (
  227. "Parameter `{deprecated_name}` is deprecated since version "
  228. "{deprecated_version} and will be removed in {changed_version} (or "
  229. "later). To avoid this warning, please use the parameter `{new_name}` "
  230. "instead. For more details, see the documentation of `{func_name}`."
  231. )
  232. def __init__(
  233. self,
  234. deprecated_name,
  235. *,
  236. start_version,
  237. stop_version,
  238. template=None,
  239. new_name=None,
  240. modify_docstring=True,
  241. stacklevel=None,
  242. ):
  243. self.deprecated_name = deprecated_name
  244. self.new_name = new_name
  245. self.template = template
  246. self.start_version = start_version
  247. self.stop_version = stop_version
  248. self.modify_docstring = modify_docstring
  249. self.stacklevel = stacklevel
  250. def __call__(self, func):
  251. parameters = inspect.signature(func).parameters
  252. try:
  253. deprecated_idx = list(parameters.keys()).index(self.deprecated_name)
  254. except ValueError as e:
  255. raise ValueError(f"{self.deprecated_name!r} not in parameters") from e
  256. new_idx = False
  257. if self.new_name:
  258. try:
  259. new_idx = list(parameters.keys()).index(self.new_name)
  260. except ValueError as e:
  261. raise ValueError(f"{self.new_name!r} not in parameters") from e
  262. if parameters[self.deprecated_name].default is not DEPRECATED:
  263. raise RuntimeError(
  264. f"Expected `{self.deprecated_name}` to have the value {DEPRECATED!r} "
  265. f"to indicate its status in the rendered signature."
  266. )
  267. if self.template is not None:
  268. template = self.template
  269. elif self.new_name is not None:
  270. template = self.replace_parameter_template
  271. else:
  272. template = self.remove_parameter_template
  273. warning_message = template.format(
  274. deprecated_name=self.deprecated_name,
  275. deprecated_version=self.start_version,
  276. changed_version=self.stop_version,
  277. func_name=func.__qualname__,
  278. new_name=self.new_name,
  279. )
  280. @functools.wraps(func)
  281. def fixed_func(*args, **kwargs):
  282. deprecated_value = DEPRECATED
  283. new_value = DEPRECATED
  284. # Extract value of deprecated parameter
  285. if len(args) > deprecated_idx:
  286. deprecated_value = args[deprecated_idx]
  287. # Overwrite old with DEPRECATED if replacement exists
  288. if self.new_name is not None:
  289. args = (
  290. args[:deprecated_idx]
  291. + (DEPRECATED,)
  292. + args[deprecated_idx + 1 :]
  293. )
  294. if self.deprecated_name in kwargs.keys():
  295. deprecated_value = kwargs[self.deprecated_name]
  296. # Overwrite old with DEPRECATED if replacement exists
  297. if self.new_name is not None:
  298. kwargs[self.deprecated_name] = DEPRECATED
  299. # Extract value of new parameter (if present)
  300. if new_idx is not False and len(args) > new_idx:
  301. new_value = args[new_idx]
  302. if self.new_name and self.new_name in kwargs.keys():
  303. new_value = kwargs[self.new_name]
  304. if deprecated_value is not DEPRECATED:
  305. stacklevel = (
  306. self.stacklevel
  307. if self.stacklevel is not None
  308. else _warning_stacklevel(func)
  309. )
  310. warnings.warn(
  311. warning_message, category=FutureWarning, stacklevel=stacklevel
  312. )
  313. if new_value is not DEPRECATED:
  314. raise ValueError(
  315. f"Both deprecated parameter `{self.deprecated_name}` "
  316. f"and new parameter `{self.new_name}` are used. Use "
  317. f"only the latter to avoid conflicting values."
  318. )
  319. elif self.new_name is not None:
  320. # Assign old value to new one
  321. kwargs[self.new_name] = deprecated_value
  322. return func(*args, **kwargs)
  323. if self.modify_docstring and func.__doc__ is not None:
  324. newdoc = _docstring_add_deprecated(
  325. func, {self.deprecated_name: self.new_name}, self.start_version
  326. )
  327. fixed_func.__doc__ = newdoc
  328. return fixed_func
  329. def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
  330. """Add deprecated kwarg(s) to the "Other Params" section of a docstring.
  331. Parameters
  332. ----------
  333. func : function
  334. The function whose docstring we wish to update.
  335. kwarg_mapping : dict
  336. A dict containing {old_arg: new_arg} key/value pairs, see
  337. `deprecate_parameter`.
  338. deprecated_version : str
  339. A major.minor version string specifying when old_arg was
  340. deprecated.
  341. Returns
  342. -------
  343. new_doc : str
  344. The updated docstring. Returns the original docstring if numpydoc is
  345. not available.
  346. """
  347. if func.__doc__ is None:
  348. return None
  349. try:
  350. from numpydoc.docscrape import FunctionDoc, Parameter
  351. except ImportError:
  352. # Return an unmodified docstring if numpydoc is not available.
  353. return func.__doc__
  354. Doc = FunctionDoc(func)
  355. for old_arg, new_arg in kwarg_mapping.items():
  356. desc = []
  357. if new_arg is None:
  358. desc.append(f'`{old_arg}` is deprecated.')
  359. else:
  360. desc.append(f'Deprecated in favor of `{new_arg}`.')
  361. desc += ['', f'.. deprecated:: {deprecated_version}']
  362. Doc['Other Parameters'].append(
  363. Parameter(name=old_arg, type='DEPRECATED', desc=desc)
  364. )
  365. new_docstring = str(Doc)
  366. # new_docstring will have a header starting with:
  367. #
  368. # .. function:: func.__name__
  369. #
  370. # and some additional blank lines. We strip these off below.
  371. split = new_docstring.split('\n')
  372. no_header = split[1:]
  373. while not no_header[0].strip():
  374. no_header.pop(0)
  375. # Store the initial description before any of the Parameters fields.
  376. # Usually this is a single line, but the while loop covers any case
  377. # where it is not.
  378. descr = no_header.pop(0)
  379. while no_header[0].strip():
  380. descr += '\n ' + no_header.pop(0)
  381. descr += '\n\n'
  382. # '\n ' rather than '\n' here to restore the original indentation.
  383. final_docstring = descr + '\n '.join(no_header)
  384. # strip any extra spaces from ends of lines
  385. final_docstring = '\n'.join([line.rstrip() for line in final_docstring.split('\n')])
  386. return final_docstring
  387. class FailedEstimationAccessError(AttributeError):
  388. """Error from use of failed estimation instance
  389. This error arises from attempts to use an instance of
  390. :class:`FailedEstimation`.
  391. """
  392. class FailedEstimation:
  393. """Class to indicate a failed transform estimation.
  394. The ``from_estimate`` class method of each transform type may return an
  395. instance of this class to indicate some failure in the estimation process.
  396. Parameters
  397. ----------
  398. message : str
  399. Message indicating reason for failed estimation.
  400. Attributes
  401. ----------
  402. message : str
  403. Message above.
  404. Raises
  405. ------
  406. FailedEstimationAccessError
  407. Exception raised for missing attributes or if the instance is used as a
  408. callable.
  409. """
  410. error_cls = FailedEstimationAccessError
  411. hint = (
  412. "You can check for a failed estimation by truth testing the returned "
  413. "object. For failed estimations, `bool(estimation_result)` will be `False`. "
  414. "E.g.\n\n"
  415. " if not estimation_result:\n"
  416. " raise RuntimeError(f'Failed estimation: {estimation_result}')"
  417. )
  418. def __init__(self, message):
  419. self.message = message
  420. def __bool__(self):
  421. return False
  422. def __repr__(self):
  423. return f"{type(self).__name__}({self.message!r})"
  424. def __str__(self):
  425. return self.message
  426. def __call__(self, *args, **kwargs):
  427. msg = (
  428. f'{type(self).__name__} is not callable. {self.message}\n\n'
  429. f'Hint: {self.hint}'
  430. )
  431. raise self.error_cls(msg)
  432. def __getattr__(self, name):
  433. msg = (
  434. f'{type(self).__name__} has no attribute {name!r}. {self.message}\n\n'
  435. f'Hint: {self.hint}'
  436. )
  437. raise self.error_cls(msg)
  438. @contextmanager
  439. def _ignore_deprecated_estimate_warning():
  440. """Filter warnings about the deprecated `estimate` method.
  441. Use either as decorator or context manager.
  442. """
  443. with warnings.catch_warnings():
  444. warnings.filterwarnings(
  445. action="ignore",
  446. category=FutureWarning,
  447. message="`estimate` is deprecated",
  448. module="skimage",
  449. )
  450. yield
  451. class channel_as_last_axis:
  452. """Decorator for automatically making channels axis last for all arrays.
  453. This decorator reorders axes for compatibility with functions that only
  454. support channels along the last axis. After the function call is complete
  455. the channels axis is restored back to its original position.
  456. Parameters
  457. ----------
  458. channel_arg_positions : tuple of int, optional
  459. Positional arguments at the positions specified in this tuple are
  460. assumed to be multichannel arrays. The default is to assume only the
  461. first argument to the function is a multichannel array.
  462. channel_kwarg_names : tuple of str, optional
  463. A tuple containing the names of any keyword arguments corresponding to
  464. multichannel arrays.
  465. multichannel_output : bool, optional
  466. A boolean that should be True if the output of the function is not a
  467. multichannel array and False otherwise. This decorator does not
  468. currently support the general case of functions with multiple outputs
  469. where some or all are multichannel.
  470. """
  471. def __init__(
  472. self,
  473. channel_arg_positions=(0,),
  474. channel_kwarg_names=(),
  475. multichannel_output=True,
  476. ):
  477. self.arg_positions = set(channel_arg_positions)
  478. self.kwarg_names = set(channel_kwarg_names)
  479. self.multichannel_output = multichannel_output
  480. def __call__(self, func):
  481. @functools.wraps(func)
  482. def fixed_func(*args, **kwargs):
  483. channel_axis = kwargs.get('channel_axis', None)
  484. if channel_axis is None:
  485. return func(*args, **kwargs)
  486. # TODO: convert scalars to a tuple in anticipation of eventually
  487. # supporting a tuple of channel axes. Right now, only an
  488. # integer or a single-element tuple is supported, though.
  489. if np.isscalar(channel_axis):
  490. channel_axis = (channel_axis,)
  491. if len(channel_axis) > 1:
  492. raise ValueError("only a single channel axis is currently supported")
  493. if channel_axis == (-1,) or channel_axis == -1:
  494. return func(*args, **kwargs)
  495. if self.arg_positions:
  496. new_args = []
  497. for pos, arg in enumerate(args):
  498. if pos in self.arg_positions:
  499. new_args.append(np.moveaxis(arg, channel_axis[0], -1))
  500. else:
  501. new_args.append(arg)
  502. new_args = tuple(new_args)
  503. else:
  504. new_args = args
  505. for name in self.kwarg_names:
  506. kwargs[name] = np.moveaxis(kwargs[name], channel_axis[0], -1)
  507. # now that we have moved the channels axis to the last position,
  508. # change the channel_axis argument to -1
  509. kwargs["channel_axis"] = -1
  510. # Call the function with the fixed arguments
  511. out = func(*new_args, **kwargs)
  512. if self.multichannel_output:
  513. out = np.moveaxis(out, -1, channel_axis[0])
  514. return out
  515. return fixed_func
  516. class deprecate_func:
  517. """Decorate a deprecated function and warn when it is called.
  518. Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.
  519. Parameters
  520. ----------
  521. deprecated_version : str
  522. The package version when the deprecation was introduced.
  523. removed_version : str
  524. The package version in which the deprecated function will be removed.
  525. hint : str, optional
  526. A hint on how to address this deprecation,
  527. e.g., "Use `skimage.submodule.alternative_func` instead."
  528. stacklevel : {None, int}, optional
  529. If None, the decorator attempts to detect the appropriate stacklevel for the
  530. deprecation warning automatically. This can fail, e.g., due to
  531. decorating a closure, in which case you can set the stacklevel manually
  532. here. The outermost decorator should have stacklevel 2, the next inner
  533. one stacklevel 3, etc.
  534. Examples
  535. --------
  536. >>> @deprecate_func(
  537. ... deprecated_version="1.0.0",
  538. ... removed_version="1.2.0",
  539. ... hint="Use `bar` instead."
  540. ... )
  541. ... def foo():
  542. ... pass
  543. Calling ``foo`` will warn with::
  544. FutureWarning: `foo` is deprecated since version 1.0.0
  545. and will be removed in version 1.2.0. Use `bar` instead.
  546. """
  547. def __init__(
  548. self, *, deprecated_version, removed_version=None, hint=None, stacklevel=None
  549. ):
  550. self.deprecated_version = deprecated_version
  551. self.removed_version = removed_version
  552. self.hint = hint
  553. self.stacklevel = stacklevel
  554. def __call__(self, func):
  555. message = (
  556. f"`{func.__name__}` is deprecated since version {self.deprecated_version}"
  557. )
  558. if self.removed_version:
  559. message += f" and will be removed in version {self.removed_version}."
  560. if self.hint:
  561. # Prepend space and make sure it closes with "."
  562. message += f" {self.hint.rstrip('.')}."
  563. @functools.wraps(func)
  564. def wrapped(*args, **kwargs):
  565. stacklevel = (
  566. self.stacklevel
  567. if self.stacklevel is not None
  568. else _warning_stacklevel(func)
  569. )
  570. warnings.warn(message, category=FutureWarning, stacklevel=stacklevel)
  571. return func(*args, **kwargs)
  572. # modify docstring to display deprecation warning
  573. doc = f'**Deprecated:** {message}'
  574. if wrapped.__doc__ is None:
  575. wrapped.__doc__ = doc
  576. else:
  577. wrapped.__doc__ = doc + '\n\n ' + wrapped.__doc__
  578. return wrapped
  579. def _deprecate_estimate(func, class_name=None):
  580. """Deprecate ``estimate`` method."""
  581. class_name = func.__qualname__.split('.')[0] if class_name is None else class_name
  582. return deprecate_func(
  583. deprecated_version="0.26",
  584. removed_version="2.2",
  585. hint=f"Please use `{class_name}.from_estimate` class constructor instead.",
  586. stacklevel=2,
  587. )(func)
  588. def _deprecate_inherited_estimate(cls):
  589. """Deprecate inherited ``estimate`` instance method.
  590. This needs a class decorator so we can correctly specify the class of the
  591. `from_estimate` class method in the deprecation message.
  592. """
  593. def estimate(self, *args, **kwargs):
  594. return self._estimate(*args, **kwargs) is None
  595. # The inherited method will always be wrapped by deprecator.
  596. inherited_meth = getattr(cls, 'estimate').__wrapped__
  597. estimate.__doc__ = inherited_meth.__doc__
  598. estimate.__signature__ = inspect.signature(inherited_meth)
  599. cls.estimate = _deprecate_estimate(estimate, cls.__name__)
  600. return cls
  601. def _update_from_estimate_docstring(cls):
  602. """Fix docstring for inherited ``from_estimate`` class method.
  603. Even for classes that inherit the `from_estimate` method, and do not
  604. override it, we nevertheless need to change the *docstring* of the
  605. `from_estimate` method to point the user to the current (inheriting) class,
  606. rather than the class in which the method is defined (the inherited class).
  607. This needs a class decorator so we can modify the docstring of the new
  608. class method. CPython currently does not allow us to modify class method
  609. docstrings by updating ``__doc__``.
  610. """
  611. inherited_cmeth = getattr(cls, 'from_estimate')
  612. def from_estimate(cls, *args, **kwargs):
  613. return inherited_cmeth(*args, **kwargs)
  614. inherited_class_name = inherited_cmeth.__qualname__.split('.')[-2]
  615. from_estimate.__doc__ = inherited_cmeth.__doc__.replace(
  616. inherited_class_name, cls.__name__
  617. )
  618. from_estimate.__signature__ = inspect.signature(inherited_cmeth)
  619. cls.from_estimate = classmethod(from_estimate)
  620. return cls
  621. def get_bound_method_class(m):
  622. """Return the class for a bound method."""
  623. return m.im_class if sys.version < '3' else m.__self__.__class__
  624. def safe_as_int(val, atol=1e-3):
  625. """
  626. Attempt to safely cast values to integer format.
  627. Parameters
  628. ----------
  629. val : scalar or iterable of scalars
  630. Number or container of numbers which are intended to be interpreted as
  631. integers, e.g., for indexing purposes, but which may not carry integer
  632. type.
  633. atol : float
  634. Absolute tolerance away from nearest integer to consider values in
  635. ``val`` functionally integers.
  636. Returns
  637. -------
  638. val_int : NumPy scalar or ndarray of dtype `np.int64`
  639. Returns the input value(s) coerced to dtype `np.int64` assuming all
  640. were within ``atol`` of the nearest integer.
  641. Notes
  642. -----
  643. This operation calculates ``val`` modulo 1, which returns the mantissa of
  644. all values. Then all mantissas greater than 0.5 are subtracted from one.
  645. Finally, the absolute tolerance from zero is calculated. If it is less
  646. than ``atol`` for all value(s) in ``val``, they are rounded and returned
  647. in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
  648. returned.
  649. If any value(s) are outside the specified tolerance, an informative error
  650. is raised.
  651. Examples
  652. --------
  653. >>> safe_as_int(7.0)
  654. 7
  655. >>> safe_as_int([9, 4, 2.9999999999])
  656. array([9, 4, 3])
  657. >>> safe_as_int(53.1)
  658. Traceback (most recent call last):
  659. ...
  660. ValueError: Integer argument required but received 53.1, check inputs.
  661. >>> safe_as_int(53.01, atol=0.01)
  662. 53
  663. """
  664. mod = np.asarray(val) % 1 # Extract mantissa
  665. # Check for and subtract any mod values > 0.5 from 1
  666. if mod.ndim == 0: # Scalar input, cannot be indexed
  667. if mod > 0.5:
  668. mod = 1 - mod
  669. else: # Iterable input, now ndarray
  670. mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int
  671. if not np.allclose(mod, 0, atol=atol):
  672. raise ValueError(f'Integer argument required but received {val}, check inputs.')
  673. return np.round(val).astype(np.int64)
  674. def check_shape_equality(*images):
  675. """Check that all images have the same shape"""
  676. image0 = images[0]
  677. if not all(image0.shape == image.shape for image in images[1:]):
  678. raise ValueError('Input images must have the same dimensions.')
  679. return
  680. def slice_at_axis(sl, axis):
  681. """
  682. Construct tuple of slices to slice an array in the given dimension.
  683. Parameters
  684. ----------
  685. sl : slice
  686. The slice for the given dimension.
  687. axis : int
  688. The axis to which `sl` is applied. All other dimensions are left
  689. "unsliced".
  690. Returns
  691. -------
  692. sl : tuple of slices
  693. A tuple with slices matching `shape` in length.
  694. Examples
  695. --------
  696. >>> slice_at_axis(slice(None, 3, -1), 1)
  697. (slice(None, None, None), slice(None, 3, -1), Ellipsis)
  698. """
  699. return (slice(None),) * axis + (sl,) + (...,)
  700. def reshape_nd(arr, ndim, dim):
  701. """Reshape a 1D array to have n dimensions, all singletons but one.
  702. Parameters
  703. ----------
  704. arr : array, shape (N,)
  705. Input array
  706. ndim : int
  707. Number of desired dimensions of reshaped array.
  708. dim : int
  709. Which dimension/axis will not be singleton-sized.
  710. Returns
  711. -------
  712. arr_reshaped : array, shape ([1, ...], N, [1,...])
  713. View of `arr` reshaped to the desired shape.
  714. Examples
  715. --------
  716. >>> rng = np.random.default_rng()
  717. >>> arr = rng.random(7)
  718. >>> reshape_nd(arr, 2, 0).shape
  719. (7, 1)
  720. >>> reshape_nd(arr, 3, 1).shape
  721. (1, 7, 1)
  722. >>> reshape_nd(arr, 4, -1).shape
  723. (1, 1, 1, 7)
  724. """
  725. if arr.ndim != 1:
  726. raise ValueError("arr must be a 1D array")
  727. new_shape = [1] * ndim
  728. new_shape[dim] = -1
  729. return np.reshape(arr, new_shape)
  730. def check_nD(array, ndim, arg_name='image'):
  731. """
  732. Verify an array meets the desired ndims and array isn't empty.
  733. Parameters
  734. ----------
  735. array : array-like
  736. Input array to be validated
  737. ndim : int or iterable of ints
  738. Allowable ndim or ndims for the array.
  739. arg_name : str, optional
  740. The name of the array in the original function.
  741. """
  742. array = np.asanyarray(array)
  743. msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
  744. msg_empty_array = "The parameter `%s` cannot be an empty array"
  745. if isinstance(ndim, int):
  746. ndim = [ndim]
  747. if array.size == 0:
  748. raise ValueError(msg_empty_array % (arg_name))
  749. if array.ndim not in ndim:
  750. raise ValueError(
  751. msg_incorrect_dim % (arg_name, '-or-'.join([str(n) for n in ndim]))
  752. )
  753. def convert_to_float(image, preserve_range):
  754. """Convert input image to float image with the appropriate range.
  755. Parameters
  756. ----------
  757. image : ndarray
  758. Input image.
  759. preserve_range : bool
  760. Determines if the range of the image should be kept or transformed
  761. using img_as_float. Also see
  762. https://scikit-image.org/docs/dev/user_guide/data_types.html
  763. Notes
  764. -----
  765. * Input images with `float32` data type are not upcast.
  766. Returns
  767. -------
  768. image : ndarray
  769. Transformed version of the input.
  770. """
  771. if image.dtype == np.float16:
  772. return image.astype(np.float32)
  773. if preserve_range:
  774. # Convert image to double only if it is not single or double
  775. # precision float
  776. if image.dtype.char not in 'df':
  777. image = image.astype(float)
  778. else:
  779. from ..util.dtype import img_as_float
  780. image = img_as_float(image)
  781. return image
  782. def _validate_interpolation_order(image_dtype, order):
  783. """Validate and return spline interpolation's order.
  784. Parameters
  785. ----------
  786. image_dtype : dtype
  787. Image dtype.
  788. order : {None, int}, optional
  789. The order of the spline interpolation. The order has to be in the range
  790. 0-5. If ``None`` assume order 0 for Boolean images, otherwise 1. See
  791. `skimage.transform.warp` for detail.
  792. Returns
  793. -------
  794. order : int
  795. if input order is None, returns 0 if image_dtype is bool and 1
  796. otherwise. Otherwise, image_dtype is checked and input order
  797. is validated accordingly (order > 0 is not supported for bool
  798. image dtype)
  799. """
  800. if order is None:
  801. return 0 if image_dtype == bool else 1
  802. if order < 0 or order > 5:
  803. raise ValueError("Spline interpolation order has to be in the range 0-5.")
  804. if image_dtype == bool and order != 0:
  805. raise ValueError(
  806. "Input image dtype is bool. Interpolation is not defined "
  807. "with bool data type. Please set order to 0 or explicitly "
  808. "cast input image to another data type."
  809. )
  810. return order
  811. def _to_np_mode(mode):
  812. """Convert padding modes from `ndi.correlate` to `np.pad`."""
  813. mode_translation_dict = dict(nearest='edge', reflect='symmetric', mirror='reflect')
  814. if mode in mode_translation_dict:
  815. mode = mode_translation_dict[mode]
  816. return mode
  817. def _to_ndimage_mode(mode):
  818. """Convert from `numpy.pad` mode name to the corresponding ndimage mode."""
  819. mode_translation_dict = dict(
  820. constant='constant',
  821. edge='nearest',
  822. symmetric='reflect',
  823. reflect='mirror',
  824. wrap='wrap',
  825. )
  826. if mode not in mode_translation_dict:
  827. raise ValueError(
  828. f"Unknown mode: '{mode}', or cannot translate mode. The "
  829. f"mode should be one of 'constant', 'edge', 'symmetric', "
  830. f"'reflect', or 'wrap'. See the documentation of numpy.pad for "
  831. f"more info."
  832. )
  833. return _fix_ndimage_mode(mode_translation_dict[mode])
  834. def _fix_ndimage_mode(mode):
  835. # SciPy 1.6.0 introduced grid variants of constant and wrap which
  836. # have less surprising behavior for images. Use these when available
  837. grid_modes = {'constant': 'grid-constant', 'wrap': 'grid-wrap'}
  838. return grid_modes.get(mode, mode)
  839. new_float_type = {
  840. # preserved types
  841. np.float32().dtype.char: np.float32,
  842. np.float64().dtype.char: np.float64,
  843. np.complex64().dtype.char: np.complex64,
  844. np.complex128().dtype.char: np.complex128,
  845. # altered types
  846. np.float16().dtype.char: np.float32,
  847. 'g': np.float64, # np.float128 ; doesn't exist on windows
  848. 'G': np.complex128, # np.complex256 ; doesn't exist on windows
  849. }
  850. def _supported_float_type(input_dtype, allow_complex=False):
  851. """Return an appropriate floating-point dtype for a given dtype.
  852. float32, float64, complex64, complex128 are preserved.
  853. float16 is promoted to float32.
  854. complex256 is demoted to complex128.
  855. Other types are cast to float64.
  856. Parameters
  857. ----------
  858. input_dtype : np.dtype or tuple of np.dtype
  859. The input dtype. If a tuple of multiple dtypes is provided, each
  860. dtype is first converted to a supported floating point type and the
  861. final dtype is then determined by applying `np.result_type` on the
  862. sequence of supported floating point types.
  863. allow_complex : bool, optional
  864. If False, raise a ValueError on complex-valued inputs.
  865. Returns
  866. -------
  867. float_type : dtype
  868. Floating-point dtype for the image.
  869. """
  870. if isinstance(input_dtype, tuple):
  871. return np.result_type(*(_supported_float_type(d) for d in input_dtype))
  872. input_dtype = np.dtype(input_dtype)
  873. if not allow_complex and input_dtype.kind == 'c':
  874. raise ValueError("complex valued input is not supported")
  875. return new_float_type.get(input_dtype.char, np.float64)
  876. def identity(image, *args, **kwargs):
  877. """Returns the first argument unmodified."""
  878. return image
  879. def as_binary_ndarray(array, *, variable_name):
  880. """Return `array` as a numpy.ndarray of dtype bool.
  881. Raises
  882. ------
  883. ValueError:
  884. An error including the given `variable_name` if `array` can not be
  885. safely cast to a boolean array.
  886. """
  887. array = np.asarray(array)
  888. if array.dtype != bool:
  889. if np.any((array != 1) & (array != 0)):
  890. raise ValueError(
  891. f"{variable_name} array is not of dtype boolean or "
  892. f"contains values other than 0 and 1 so cannot be "
  893. f"safely cast to boolean array."
  894. )
  895. return np.asarray(array, dtype=bool)