arraycrop.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """
  2. The arraycrop module contains functions to crop values from the edges of an
  3. n-dimensional array.
  4. """
  5. import numpy as np
  6. from numbers import Integral
  7. __all__ = ['crop']
  8. def crop(ar, crop_width, copy=False, order='K'):
  9. """Crop array `ar` by `crop_width` along each dimension.
  10. Parameters
  11. ----------
  12. ar : array-like of rank N
  13. Input array.
  14. crop_width : {sequence, int}
  15. Number of values to remove from the edges of each axis.
  16. ``((before_1, after_1),`` ... ``(before_N, after_N))`` specifies
  17. unique crop widths at the start and end of each axis.
  18. ``((before, after),) or (before, after)`` specifies
  19. a fixed start and end crop for every axis.
  20. ``(n,)`` or ``n`` for integer ``n`` is a shortcut for
  21. before = after = ``n`` for all axes.
  22. copy : bool, optional
  23. If `True`, ensure the returned array is a contiguous copy. Normally,
  24. a crop operation will return a discontiguous view of the underlying
  25. input array.
  26. order : {'C', 'F', 'A', 'K'}, optional
  27. If ``copy==True``, control the memory layout of the copy. See
  28. ``np.copy``.
  29. Returns
  30. -------
  31. cropped : array
  32. The cropped array. If ``copy=False`` (default), this is a sliced
  33. view of the input array.
  34. """
  35. ar = np.array(ar, copy=False)
  36. if isinstance(crop_width, Integral):
  37. crops = [[crop_width, crop_width]] * ar.ndim
  38. elif isinstance(crop_width[0], Integral):
  39. if len(crop_width) == 1:
  40. crops = [[crop_width[0], crop_width[0]]] * ar.ndim
  41. elif len(crop_width) == 2:
  42. crops = [crop_width] * ar.ndim
  43. else:
  44. raise ValueError(
  45. f'crop_width has an invalid length: {len(crop_width)}\n'
  46. f'crop_width should be a sequence of N pairs, '
  47. f'a single pair, or a single integer'
  48. )
  49. elif len(crop_width) == 1:
  50. crops = [crop_width[0]] * ar.ndim
  51. elif len(crop_width) == ar.ndim:
  52. crops = crop_width
  53. else:
  54. raise ValueError(
  55. f'crop_width has an invalid length: {len(crop_width)}\n'
  56. f'crop_width should be a sequence of N pairs, '
  57. f'a single pair, or a single integer'
  58. )
  59. slices = tuple(slice(a, ar.shape[i] - b) for i, (a, b) in enumerate(crops))
  60. if copy:
  61. cropped = np.array(ar[slices], order=order, copy=True)
  62. else:
  63. cropped = ar[slices]
  64. return cropped