util.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # -*- coding: utf-8 -*-
  2. # imageio is distributed under the terms of the (new) BSD License.
  3. """
  4. Various utilities for imageio
  5. """
  6. from collections import OrderedDict
  7. import numpy as np
  8. import os
  9. import re
  10. import struct
  11. import sys
  12. import time
  13. import logging
  14. logger = logging.getLogger("imageio")
  15. IS_PYPY = "__pypy__" in sys.builtin_module_names
  16. THIS_DIR = os.path.abspath(os.path.dirname(__file__))
  17. def urlopen(*args, **kwargs):
  18. """Compatibility function for the urlopen function. Raises an
  19. RuntimeError if urlopen could not be imported (which can occur in
  20. frozen applications.
  21. """
  22. try:
  23. from urllib.request import urlopen
  24. except ImportError:
  25. raise RuntimeError("Could not import urlopen.")
  26. return urlopen(*args, **kwargs)
  27. def _precision_warn(p1, p2, extra=""):
  28. t = (
  29. "Lossy conversion from {} to {}. {} Convert image to {} prior to "
  30. "saving to suppress this warning."
  31. )
  32. logger.warning(t.format(p1, p2, extra, p2))
  33. def image_as_uint(im, bitdepth=None):
  34. """Convert the given image to uint (default: uint8)
  35. If the dtype already matches the desired format, it is returned
  36. as-is. If the image is float, and all values are between 0 and 1,
  37. the values are multiplied by np.power(2.0, bitdepth). In all other
  38. situations, the values are scaled such that the minimum value
  39. becomes 0 and the maximum value becomes np.power(2.0, bitdepth)-1
  40. (255 for 8-bit and 65535 for 16-bit).
  41. """
  42. if not bitdepth:
  43. bitdepth = 8
  44. if not isinstance(im, np.ndarray):
  45. raise ValueError("Image must be a numpy array")
  46. if bitdepth == 8:
  47. out_type = np.uint8
  48. elif bitdepth == 16:
  49. out_type = np.uint16
  50. else:
  51. raise ValueError("Bitdepth must be either 8 or 16")
  52. dtype_str1 = str(im.dtype)
  53. dtype_str2 = out_type.__name__
  54. if (im.dtype == np.uint8 and bitdepth == 8) or (
  55. im.dtype == np.uint16 and bitdepth == 16
  56. ):
  57. # Already the correct format? Return as-is
  58. return im
  59. if dtype_str1.startswith("float") and np.nanmin(im) >= 0 and np.nanmax(im) <= 1:
  60. _precision_warn(dtype_str1, dtype_str2, "Range [0, 1].")
  61. im = im.astype(np.float64) * (np.power(2.0, bitdepth) - 1) + 0.499999999
  62. elif im.dtype == np.uint16 and bitdepth == 8:
  63. _precision_warn(dtype_str1, dtype_str2, "Losing 8 bits of resolution.")
  64. im = np.right_shift(im, 8)
  65. elif im.dtype == np.uint32:
  66. _precision_warn(
  67. dtype_str1,
  68. dtype_str2,
  69. "Losing {} bits of resolution.".format(32 - bitdepth),
  70. )
  71. im = np.right_shift(im, 32 - bitdepth)
  72. elif im.dtype == np.uint64:
  73. _precision_warn(
  74. dtype_str1,
  75. dtype_str2,
  76. "Losing {} bits of resolution.".format(64 - bitdepth),
  77. )
  78. im = np.right_shift(im, 64 - bitdepth)
  79. else:
  80. mi = np.nanmin(im)
  81. ma = np.nanmax(im)
  82. if not np.isfinite(mi):
  83. raise ValueError("Minimum image value is not finite")
  84. if not np.isfinite(ma):
  85. raise ValueError("Maximum image value is not finite")
  86. if ma == mi:
  87. return im.astype(out_type)
  88. _precision_warn(dtype_str1, dtype_str2, "Range [{}, {}].".format(mi, ma))
  89. # Now make float copy before we scale
  90. im = im.astype("float64")
  91. # Scale the values between 0 and 1 then multiply by the max value
  92. im = (im - mi) / (ma - mi) * (np.power(2.0, bitdepth) - 1) + 0.499999999
  93. assert np.nanmin(im) >= 0
  94. assert np.nanmax(im) < np.power(2.0, bitdepth)
  95. return im.astype(out_type)
  96. class Array(np.ndarray):
  97. """Array(array, meta=None)
  98. A subclass of np.ndarray that has a meta attribute. Get the dictionary
  99. that contains the meta data using ``im.meta``. Convert to a plain numpy
  100. array using ``np.asarray(im)``.
  101. """
  102. def __new__(cls, array, meta=None):
  103. # Check
  104. if not isinstance(array, np.ndarray):
  105. raise ValueError("Array expects a numpy array.")
  106. if not (meta is None or isinstance(meta, dict)):
  107. raise ValueError("Array expects meta data to be a dict.")
  108. # Convert and return
  109. meta = meta if meta is not None else getattr(array, "meta", {})
  110. try:
  111. ob = array.view(cls)
  112. except AttributeError: # pragma: no cover
  113. # Just return the original; no metadata on the array in Pypy!
  114. return array
  115. ob._copy_meta(meta)
  116. return ob
  117. def _copy_meta(self, meta):
  118. """Make a 2-level deep copy of the meta dictionary."""
  119. self._meta = Dict()
  120. for key, val in meta.items():
  121. if isinstance(val, dict):
  122. val = Dict(val) # Copy this level
  123. self._meta[key] = val
  124. @property
  125. def meta(self):
  126. """The dict with the meta data of this image."""
  127. return self._meta
  128. def __array_finalize__(self, ob):
  129. """So the meta info is maintained when doing calculations with
  130. the array.
  131. """
  132. if isinstance(ob, Array):
  133. self._copy_meta(ob.meta)
  134. else:
  135. self._copy_meta({})
  136. def __array_wrap__(self, out, context=None):
  137. """So that we return a native numpy array (or scalar) when a
  138. reducting ufunc is applied (such as sum(), std(), etc.)
  139. """
  140. if not out.shape:
  141. return out.dtype.type(out) # Scalar
  142. elif out.shape != self.shape:
  143. return out.view(type=np.ndarray)
  144. elif not isinstance(out, Array):
  145. return Array(out, self.meta)
  146. else:
  147. return out # Type Array
  148. Image = Array # Alias for backwards compatibility
  149. def asarray(a):
  150. """Pypy-safe version of np.asarray. Pypy's np.asarray consumes a
  151. *lot* of memory if the given array is an ndarray subclass. This
  152. function does not.
  153. """
  154. if isinstance(a, np.ndarray):
  155. if IS_PYPY: # pragma: no cover
  156. a = a.copy() # pypy has issues with base views
  157. plain = a.view(type=np.ndarray)
  158. return plain
  159. return np.asarray(a)
  160. class Dict(OrderedDict):
  161. """A dict in which the keys can be get and set as if they were
  162. attributes. Very convenient in combination with autocompletion.
  163. This Dict still behaves as much as possible as a normal dict, and
  164. keys can be anything that are otherwise valid keys. However,
  165. keys that are not valid identifiers or that are names of the dict
  166. class (such as 'items' and 'copy') cannot be get/set as attributes.
  167. """
  168. __reserved_names__ = dir(OrderedDict()) # Also from OrderedDict
  169. __pure_names__ = dir(dict())
  170. def __getattribute__(self, key):
  171. try:
  172. return object.__getattribute__(self, key)
  173. except AttributeError:
  174. if key in self:
  175. return self[key]
  176. else:
  177. raise
  178. def __setattr__(self, key, val):
  179. if key in Dict.__reserved_names__:
  180. # Either let OrderedDict do its work, or disallow
  181. if key not in Dict.__pure_names__:
  182. return OrderedDict.__setattr__(self, key, val)
  183. else:
  184. raise AttributeError(
  185. "Reserved name, this key can only "
  186. + "be set via ``d[%r] = X``" % key
  187. )
  188. else:
  189. # if isinstance(val, dict): val = Dict(val) -> no, makes a copy!
  190. self[key] = val
  191. def __dir__(self):
  192. def isidentifier(x):
  193. return bool(re.match(r"[a-z_]\w*$", x, re.I))
  194. names = [k for k in self.keys() if (isinstance(k, str) and isidentifier(k))]
  195. return Dict.__reserved_names__ + names
  196. class BaseProgressIndicator(object):
  197. """BaseProgressIndicator(name)
  198. A progress indicator helps display the progress of a task to the
  199. user. Progress can be pending, running, finished or failed.
  200. Each task has:
  201. * a name - a short description of what needs to be done.
  202. * an action - the current action in performing the task (e.g. a subtask)
  203. * progress - how far the task is completed
  204. * max - max number of progress units. If 0, the progress is indefinite
  205. * unit - the units in which the progress is counted
  206. * status - 0: pending, 1: in progress, 2: finished, 3: failed
  207. This class defines an abstract interface. Subclasses should implement
  208. _start, _stop, _update_progress(progressText), _write(message).
  209. """
  210. def __init__(self, name):
  211. self._name = name
  212. self._action = ""
  213. self._unit = ""
  214. self._max = 0
  215. self._status = 0
  216. self._last_progress_update = 0
  217. def start(self, action="", unit="", max=0):
  218. """start(action='', unit='', max=0)
  219. Start the progress. Optionally specify an action, a unit,
  220. and a maximum progress value.
  221. """
  222. if self._status == 1:
  223. self.finish()
  224. self._action = action
  225. self._unit = unit
  226. self._max = max
  227. #
  228. self._progress = 0
  229. self._status = 1
  230. self._start()
  231. def status(self):
  232. """status()
  233. Get the status of the progress - 0: pending, 1: in progress,
  234. 2: finished, 3: failed
  235. """
  236. return self._status
  237. def set_progress(self, progress=0, force=False):
  238. """set_progress(progress=0, force=False)
  239. Set the current progress. To avoid unnecessary progress updates
  240. this will only have a visual effect if the time since the last
  241. update is > 0.1 seconds, or if force is True.
  242. """
  243. self._progress = progress
  244. # Update or not?
  245. if not (force or (time.time() - self._last_progress_update > 0.1)):
  246. return
  247. self._last_progress_update = time.time()
  248. # Compose new string
  249. unit = self._unit or ""
  250. progressText = ""
  251. if unit == "%":
  252. progressText = "%2.1f%%" % progress
  253. elif self._max > 0:
  254. percent = 100 * float(progress) / self._max
  255. progressText = "%i/%i %s (%2.1f%%)" % (progress, self._max, unit, percent)
  256. elif progress > 0:
  257. if isinstance(progress, float):
  258. progressText = "%0.4g %s" % (progress, unit)
  259. else:
  260. progressText = "%i %s" % (progress, unit)
  261. # Update
  262. self._update_progress(progressText)
  263. def increase_progress(self, extra_progress):
  264. """increase_progress(extra_progress)
  265. Increase the progress by a certain amount.
  266. """
  267. self.set_progress(self._progress + extra_progress)
  268. def finish(self, message=None):
  269. """finish(message=None)
  270. Finish the progress, optionally specifying a message. This will
  271. not set the progress to the maximum.
  272. """
  273. self.set_progress(self._progress, True) # fore update
  274. self._status = 2
  275. self._stop()
  276. if message is not None:
  277. self._write(message)
  278. def fail(self, message=None):
  279. """fail(message=None)
  280. Stop the progress with a failure, optionally specifying a message.
  281. """
  282. self.set_progress(self._progress, True) # fore update
  283. self._status = 3
  284. self._stop()
  285. message = "FAIL " + (message or "")
  286. self._write(message)
  287. def write(self, message):
  288. """write(message)
  289. Write a message during progress (such as a warning).
  290. """
  291. if self.__class__ == BaseProgressIndicator:
  292. # When this class is used as a dummy, print explicit message
  293. print(message)
  294. else:
  295. return self._write(message)
  296. # Implementing classes should implement these
  297. def _start(self):
  298. pass
  299. def _stop(self):
  300. pass
  301. def _update_progress(self, progressText):
  302. pass
  303. def _write(self, message):
  304. pass
  305. class StdoutProgressIndicator(BaseProgressIndicator):
  306. """StdoutProgressIndicator(name)
  307. A progress indicator that shows the progress in stdout. It
  308. assumes that the tty can appropriately deal with backspace
  309. characters.
  310. """
  311. def _start(self):
  312. self._chars_prefix, self._chars = "", ""
  313. # Write message
  314. if self._action:
  315. self._chars_prefix = "%s (%s): " % (self._name, self._action)
  316. else:
  317. self._chars_prefix = "%s: " % self._name
  318. sys.stdout.write(self._chars_prefix)
  319. sys.stdout.flush()
  320. def _update_progress(self, progressText):
  321. # If progress is unknown, at least make something move
  322. if not progressText:
  323. i1, i2, i3, i4 = "-\\|/"
  324. M = {i1: i2, i2: i3, i3: i4, i4: i1}
  325. progressText = M.get(self._chars, i1)
  326. # Store new string and write
  327. delChars = "\b" * len(self._chars)
  328. self._chars = progressText
  329. sys.stdout.write(delChars + self._chars)
  330. sys.stdout.flush()
  331. def _stop(self):
  332. self._chars = self._chars_prefix = ""
  333. sys.stdout.write("\n")
  334. sys.stdout.flush()
  335. def _write(self, message):
  336. # Write message
  337. delChars = "\b" * len(self._chars_prefix + self._chars)
  338. sys.stdout.write(delChars + " " + message + "\n")
  339. # Reprint progress text
  340. sys.stdout.write(self._chars_prefix + self._chars)
  341. sys.stdout.flush()
  342. # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
  343. def appdata_dir(appname=None, roaming=False):
  344. """appdata_dir(appname=None, roaming=False)
  345. Get the path to the application directory, where applications are allowed
  346. to write user specific files (e.g. configurations). For non-user specific
  347. data, consider using common_appdata_dir().
  348. If appname is given, a subdir is appended (and created if necessary).
  349. If roaming is True, will prefer a roaming directory (Windows Vista/7).
  350. """
  351. # Define default user directory
  352. userDir = os.getenv("IMAGEIO_USERDIR", None)
  353. if userDir is None:
  354. userDir = os.path.expanduser("~")
  355. if not os.path.isdir(userDir): # pragma: no cover
  356. userDir = "/var/tmp" # issue #54
  357. # Get system app data dir
  358. path = None
  359. if sys.platform.startswith("win"):
  360. path1, path2 = os.getenv("LOCALAPPDATA"), os.getenv("APPDATA")
  361. path = (path2 or path1) if roaming else (path1 or path2)
  362. elif sys.platform.startswith("darwin"):
  363. path = os.path.join(userDir, "Library", "Application Support")
  364. # On Linux and as fallback
  365. if not (path and os.path.isdir(path)):
  366. path = userDir
  367. # Maybe we should store things local to the executable (in case of a
  368. # portable distro or a frozen application that wants to be portable)
  369. prefix = sys.prefix
  370. if getattr(sys, "frozen", None):
  371. prefix = os.path.abspath(os.path.dirname(sys.executable))
  372. for reldir in ("settings", "../settings"):
  373. localpath = os.path.abspath(os.path.join(prefix, reldir))
  374. if os.path.isdir(localpath): # pragma: no cover
  375. try:
  376. open(os.path.join(localpath, "test.write"), "wb").close()
  377. os.remove(os.path.join(localpath, "test.write"))
  378. except IOError:
  379. pass # We cannot write in this directory
  380. else:
  381. path = localpath
  382. break
  383. # Get path specific for this app
  384. if appname:
  385. if path == userDir:
  386. appname = "." + appname.lstrip(".") # Make it a hidden directory
  387. path = os.path.join(path, appname)
  388. if not os.path.isdir(path): # pragma: no cover
  389. os.makedirs(path, exist_ok=True)
  390. # Done
  391. return path
  392. def resource_dirs():
  393. """resource_dirs()
  394. Get a list of directories where imageio resources may be located.
  395. The first directory in this list is the "resources" directory in
  396. the package itself. The second directory is the appdata directory
  397. (~/.imageio on Linux). The list further contains the application
  398. directory (for frozen apps), and may include additional directories
  399. in the future.
  400. """
  401. dirs = [resource_package_dir()]
  402. # Resource dir baked in the package.
  403. # Appdata directory
  404. try:
  405. dirs.append(appdata_dir("imageio"))
  406. except Exception: # pragma: no cover
  407. pass # The home dir may not be writable
  408. # Directory where the app is located (mainly for frozen apps)
  409. if getattr(sys, "frozen", None):
  410. dirs.append(os.path.abspath(os.path.dirname(sys.executable)))
  411. elif sys.path and sys.path[0]:
  412. dirs.append(os.path.abspath(sys.path[0]))
  413. return dirs
  414. def resource_package_dir():
  415. """package_dir
  416. Get the resources directory in the imageio package installation
  417. directory.
  418. Notes
  419. -----
  420. This is a convenience method that is used by `resource_dirs` and
  421. imageio entry point scripts.
  422. """
  423. import importlib.resources
  424. return str(importlib.resources.files("imageio") / "resources")
  425. def get_platform():
  426. """get_platform()
  427. Get a string that specifies the platform more specific than
  428. sys.platform does. The result can be: linux32, linux64, win32,
  429. win64, osx32, osx64. Other platforms may be added in the future.
  430. """
  431. # Get platform
  432. if sys.platform.startswith("linux"):
  433. plat = "linux%i"
  434. elif sys.platform.startswith("win"):
  435. plat = "win%i"
  436. elif sys.platform.startswith("darwin"):
  437. plat = "osx%i"
  438. elif sys.platform.startswith("freebsd"):
  439. plat = "freebsd%i"
  440. else: # pragma: no cover
  441. return None
  442. return plat % (struct.calcsize("P") * 8) # 32 or 64 bits
  443. def has_module(module_name):
  444. """Check to see if a python module is available."""
  445. import importlib
  446. name_parts = module_name.split(".")
  447. for i in range(len(name_parts)):
  448. if importlib.util.find_spec(".".join(name_parts[: i + 1])) is None:
  449. return False
  450. return True