_bsdf.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915
  1. #!/usr/bin/env python
  2. # This file is distributed under the terms of the 2-clause BSD License.
  3. # Copyright (c) 2017-2018, Almar Klein
  4. """
  5. Python implementation of the Binary Structured Data Format (BSDF).
  6. BSDF is a binary format for serializing structured (scientific) data.
  7. See http://bsdf.io for more information.
  8. This is the reference implementation, which is relatively
  9. sophisticated, providing e.g. lazy loading of blobs and streamed
  10. reading/writing. A simpler Python implementation is available as
  11. ``bsdf_lite.py``.
  12. This module has no dependencies and works on Python 2.7 and 3.4+.
  13. Note: on Legacy Python (Python 2.7), non-Unicode strings are encoded as bytes.
  14. """
  15. # todo: in 2020, remove six stuff, __future__ and _isidentifier
  16. # todo: in 2020, remove 'utf-8' args to encode/decode; it's faster
  17. from __future__ import absolute_import, division, print_function
  18. import bz2
  19. import hashlib
  20. import logging
  21. import os
  22. import re
  23. import struct
  24. import sys
  25. import types
  26. import zlib
  27. from io import BytesIO
  28. logger = logging.getLogger(__name__)
  29. # Notes on versioning: the major and minor numbers correspond to the
  30. # BSDF format version. The major number if increased when backward
  31. # incompatible changes are introduced. An implementation must raise an
  32. # exception when the file being read has a higher major version. The
  33. # minor number is increased when new backward compatible features are
  34. # introduced. An implementation must display a warning when the file
  35. # being read has a higher minor version. The patch version is increased
  36. # for subsequent releases of the implementation.
  37. VERSION = 2, 1, 2
  38. __version__ = ".".join(str(i) for i in VERSION)
  39. # %% The encoder and decoder implementation
  40. # From six.py
  41. PY3 = sys.version_info[0] >= 3
  42. if PY3:
  43. text_type = str
  44. string_types = str
  45. unicode_types = str
  46. integer_types = int
  47. classtypes = type
  48. else: # pragma: no cover
  49. logging.basicConfig() # avoid "no handlers found" error
  50. text_type = unicode # noqa
  51. string_types = basestring # noqa
  52. unicode_types = unicode # noqa
  53. integer_types = (int, long) # noqa
  54. classtypes = type, types.ClassType
  55. # Shorthands
  56. spack = struct.pack
  57. strunpack = struct.unpack
  58. def lencode(x):
  59. """Encode an unsigned integer into a variable sized blob of bytes."""
  60. # We could support 16 bit and 32 bit as well, but the gain is low, since
  61. # 9 bytes for collections with over 250 elements is marginal anyway.
  62. if x <= 250:
  63. return spack("<B", x)
  64. # elif x < 65536:
  65. # return spack('<BH', 251, x)
  66. # elif x < 4294967296:
  67. # return spack('<BI', 252, x)
  68. else:
  69. return spack("<BQ", 253, x)
  70. # Include len decoder for completeness; we've inlined it for performance.
  71. def lendecode(f):
  72. """Decode an unsigned integer from a file."""
  73. n = strunpack("<B", f.read(1))[0]
  74. if n == 253:
  75. n = strunpack("<Q", f.read(8))[0] # noqa
  76. return n
  77. def encode_type_id(b, ext_id):
  78. """Encode the type identifier, with or without extension id."""
  79. if ext_id is not None:
  80. bb = ext_id.encode("UTF-8")
  81. return b.upper() + lencode(len(bb)) + bb # noqa
  82. else:
  83. return b # noqa
  84. def _isidentifier(s): # pragma: no cover
  85. """Use of str.isidentifier() for Legacy Python, but slower."""
  86. # http://stackoverflow.com/questions/2544972/
  87. return (
  88. isinstance(s, string_types)
  89. and re.match(r"^\w+$", s, re.UNICODE)
  90. and re.match(r"^[0-9]", s) is None
  91. )
  92. class BsdfSerializer(object):
  93. """Instances of this class represent a BSDF encoder/decoder.
  94. It acts as a placeholder for a set of extensions and encoding/decoding
  95. options. Use this to predefine extensions and options for high
  96. performance encoding/decoding. For general use, see the functions
  97. `save()`, `encode()`, `load()`, and `decode()`.
  98. This implementation of BSDF supports streaming lists (keep adding
  99. to a list after writing the main file), lazy loading of blobs, and
  100. in-place editing of blobs (for streams opened with a+).
  101. Options for encoding:
  102. * compression (int or str): ``0`` or "no" for no compression (default),
  103. ``1`` or "zlib" for Zlib compression (same as zip files and PNG), and
  104. ``2`` or "bz2" for Bz2 compression (more compact but slower writing).
  105. Note that some BSDF implementations (e.g. JavaScript) may not support
  106. compression.
  107. * use_checksum (bool): whether to include a checksum with binary blobs.
  108. * float64 (bool): Whether to write floats as 64 bit (default) or 32 bit.
  109. Options for decoding:
  110. * load_streaming (bool): if True, and the final object in the structure was
  111. a stream, will make it available as a stream in the decoded object.
  112. * lazy_blob (bool): if True, bytes are represented as Blob objects that can
  113. be used to lazily access the data, and also overwrite the data if the
  114. file is open in a+ mode.
  115. """
  116. def __init__(self, extensions=None, **options):
  117. self._extensions = {} # name -> extension
  118. self._extensions_by_cls = {} # cls -> (name, extension.encode)
  119. if extensions is None:
  120. extensions = standard_extensions
  121. for extension in extensions:
  122. self.add_extension(extension)
  123. self._parse_options(**options)
  124. def _parse_options(
  125. self,
  126. compression=0,
  127. use_checksum=False,
  128. float64=True,
  129. load_streaming=False,
  130. lazy_blob=False,
  131. ):
  132. # Validate compression
  133. if isinstance(compression, string_types):
  134. m = {"no": 0, "zlib": 1, "bz2": 2}
  135. compression = m.get(compression.lower(), compression)
  136. if compression not in (0, 1, 2):
  137. raise TypeError("Compression must be 0, 1, 2, " '"no", "zlib", or "bz2"')
  138. self._compression = compression
  139. # Other encoding args
  140. self._use_checksum = bool(use_checksum)
  141. self._float64 = bool(float64)
  142. # Decoding args
  143. self._load_streaming = bool(load_streaming)
  144. self._lazy_blob = bool(lazy_blob)
  145. def add_extension(self, extension_class):
  146. """Add an extension to this serializer instance, which must be
  147. a subclass of Extension. Can be used as a decorator.
  148. """
  149. # Check class
  150. if not (
  151. isinstance(extension_class, type) and issubclass(extension_class, Extension)
  152. ):
  153. raise TypeError("add_extension() expects a Extension class.")
  154. extension = extension_class()
  155. # Get name
  156. name = extension.name
  157. if not isinstance(name, str):
  158. raise TypeError("Extension name must be str.")
  159. if len(name) == 0 or len(name) > 250:
  160. raise NameError(
  161. "Extension names must be nonempty and shorter " "than 251 chars."
  162. )
  163. if name in self._extensions:
  164. logger.warning(
  165. 'BSDF warning: overwriting extension "%s", '
  166. "consider removing first" % name
  167. )
  168. # Get classes
  169. cls = extension.cls
  170. if not cls:
  171. clss = []
  172. elif isinstance(cls, (tuple, list)):
  173. clss = cls
  174. else:
  175. clss = [cls]
  176. for cls in clss:
  177. if not isinstance(cls, classtypes):
  178. raise TypeError("Extension classes must be types.")
  179. # Store
  180. for cls in clss:
  181. self._extensions_by_cls[cls] = name, extension.encode
  182. self._extensions[name] = extension
  183. return extension_class
  184. def remove_extension(self, name):
  185. """Remove a converted by its unique name."""
  186. if not isinstance(name, str):
  187. raise TypeError("Extension name must be str.")
  188. if name in self._extensions:
  189. self._extensions.pop(name)
  190. for cls in list(self._extensions_by_cls.keys()):
  191. if self._extensions_by_cls[cls][0] == name:
  192. self._extensions_by_cls.pop(cls)
  193. def _encode(self, f, value, streams, ext_id):
  194. """Main encoder function."""
  195. x = encode_type_id
  196. if value is None:
  197. f.write(x(b"v", ext_id)) # V for void
  198. elif value is True:
  199. f.write(x(b"y", ext_id)) # Y for yes
  200. elif value is False:
  201. f.write(x(b"n", ext_id)) # N for no
  202. elif isinstance(value, integer_types):
  203. if -32768 <= value <= 32767:
  204. f.write(x(b"h", ext_id) + spack("h", value)) # H for ...
  205. else:
  206. f.write(x(b"i", ext_id) + spack("<q", value)) # I for int
  207. elif isinstance(value, float):
  208. if self._float64:
  209. f.write(x(b"d", ext_id) + spack("<d", value)) # D for double
  210. else:
  211. f.write(x(b"f", ext_id) + spack("<f", value)) # f for float
  212. elif isinstance(value, unicode_types):
  213. bb = value.encode("UTF-8")
  214. f.write(x(b"s", ext_id) + lencode(len(bb))) # S for str
  215. f.write(bb)
  216. elif isinstance(value, (list, tuple)):
  217. f.write(x(b"l", ext_id) + lencode(len(value))) # L for list
  218. for v in value:
  219. self._encode(f, v, streams, None)
  220. elif isinstance(value, dict):
  221. f.write(x(b"m", ext_id) + lencode(len(value))) # M for mapping
  222. for key, v in value.items():
  223. if PY3:
  224. assert key.isidentifier() # faster
  225. else: # pragma: no cover
  226. assert _isidentifier(key)
  227. # yield ' ' * indent + key
  228. name_b = key.encode("UTF-8")
  229. f.write(lencode(len(name_b)))
  230. f.write(name_b)
  231. self._encode(f, v, streams, None)
  232. elif isinstance(value, bytes):
  233. f.write(x(b"b", ext_id)) # B for blob
  234. blob = Blob(
  235. value, compression=self._compression, use_checksum=self._use_checksum
  236. )
  237. blob._to_file(f) # noqa
  238. elif isinstance(value, Blob):
  239. f.write(x(b"b", ext_id)) # B for blob
  240. value._to_file(f) # noqa
  241. elif isinstance(value, BaseStream):
  242. # Initialize the stream
  243. if value.mode != "w":
  244. raise ValueError("Cannot serialize a read-mode stream.")
  245. elif isinstance(value, ListStream):
  246. f.write(x(b"l", ext_id) + spack("<BQ", 255, 0)) # L for list
  247. else:
  248. raise TypeError("Only ListStream is supported")
  249. # Mark this as *the* stream, and activate the stream.
  250. # The save() function verifies this is the last written object.
  251. if len(streams) > 0:
  252. raise ValueError("Can only have one stream per file.")
  253. streams.append(value)
  254. value._activate(f, self._encode, self._decode) # noqa
  255. else:
  256. if ext_id is not None:
  257. raise ValueError(
  258. "Extension %s wronfully encodes object to another "
  259. "extension object (though it may encode to a list/dict "
  260. "that contains other extension objects)." % ext_id
  261. )
  262. # Try if the value is of a type we know
  263. ex = self._extensions_by_cls.get(value.__class__, None)
  264. # Maybe its a subclass of a type we know
  265. if ex is None:
  266. for name, c in self._extensions.items():
  267. if c.match(self, value):
  268. ex = name, c.encode
  269. break
  270. else:
  271. ex = None
  272. # Success or fail
  273. if ex is not None:
  274. ext_id2, extension_encode = ex
  275. self._encode(f, extension_encode(self, value), streams, ext_id2)
  276. else:
  277. t = (
  278. "Class %r is not a valid base BSDF type, nor is it "
  279. "handled by an extension."
  280. )
  281. raise TypeError(t % value.__class__.__name__)
  282. def _decode(self, f):
  283. """Main decoder function."""
  284. # Get value
  285. char = f.read(1)
  286. c = char.lower()
  287. # Conversion (uppercase value identifiers signify converted values)
  288. if not char:
  289. raise EOFError()
  290. elif char != c:
  291. n = strunpack("<B", f.read(1))[0]
  292. # if n == 253: n = strunpack('<Q', f.read(8))[0] # noqa - noneed
  293. ext_id = f.read(n).decode("UTF-8")
  294. else:
  295. ext_id = None
  296. if c == b"v":
  297. value = None
  298. elif c == b"y":
  299. value = True
  300. elif c == b"n":
  301. value = False
  302. elif c == b"h":
  303. value = strunpack("<h", f.read(2))[0]
  304. elif c == b"i":
  305. value = strunpack("<q", f.read(8))[0]
  306. elif c == b"f":
  307. value = strunpack("<f", f.read(4))[0]
  308. elif c == b"d":
  309. value = strunpack("<d", f.read(8))[0]
  310. elif c == b"s":
  311. n_s = strunpack("<B", f.read(1))[0]
  312. if n_s == 253:
  313. n_s = strunpack("<Q", f.read(8))[0] # noqa
  314. value = f.read(n_s).decode("UTF-8")
  315. elif c == b"l":
  316. n = strunpack("<B", f.read(1))[0]
  317. if n >= 254:
  318. # Streaming
  319. closed = n == 254
  320. n = strunpack("<Q", f.read(8))[0]
  321. if self._load_streaming:
  322. value = ListStream(n if closed else "r")
  323. value._activate(f, self._encode, self._decode) # noqa
  324. elif closed:
  325. value = [self._decode(f) for i in range(n)]
  326. else:
  327. value = []
  328. try:
  329. while True:
  330. value.append(self._decode(f))
  331. except EOFError:
  332. pass
  333. else:
  334. # Normal
  335. if n == 253:
  336. n = strunpack("<Q", f.read(8))[0] # noqa
  337. value = [self._decode(f) for i in range(n)]
  338. elif c == b"m":
  339. value = dict()
  340. n = strunpack("<B", f.read(1))[0]
  341. if n == 253:
  342. n = strunpack("<Q", f.read(8))[0] # noqa
  343. for i in range(n):
  344. n_name = strunpack("<B", f.read(1))[0]
  345. if n_name == 253:
  346. n_name = strunpack("<Q", f.read(8))[0] # noqa
  347. assert n_name > 0
  348. name = f.read(n_name).decode("UTF-8")
  349. value[name] = self._decode(f)
  350. elif c == b"b":
  351. if self._lazy_blob:
  352. value = Blob((f, True))
  353. else:
  354. blob = Blob((f, False))
  355. value = blob.get_bytes()
  356. else:
  357. raise RuntimeError("Parse error %r" % char)
  358. # Convert value if we have an extension for it
  359. if ext_id is not None:
  360. extension = self._extensions.get(ext_id, None)
  361. if extension is not None:
  362. value = extension.decode(self, value)
  363. else:
  364. logger.warning("BSDF warning: no extension found for %r" % ext_id)
  365. return value
  366. def encode(self, ob):
  367. """Save the given object to bytes."""
  368. f = BytesIO()
  369. self.save(f, ob)
  370. return f.getvalue()
  371. def save(self, f, ob):
  372. """Write the given object to the given file object."""
  373. f.write(b"BSDF")
  374. f.write(struct.pack("<B", VERSION[0]))
  375. f.write(struct.pack("<B", VERSION[1]))
  376. # Prepare streaming, this list will have 0 or 1 item at the end
  377. streams = []
  378. self._encode(f, ob, streams, None)
  379. # Verify that stream object was at the end, and add initial elements
  380. if len(streams) > 0:
  381. stream = streams[0]
  382. if stream._start_pos != f.tell():
  383. raise ValueError(
  384. "The stream object must be " "the last object to be encoded."
  385. )
  386. def decode(self, bb):
  387. """Load the data structure that is BSDF-encoded in the given bytes."""
  388. f = BytesIO(bb)
  389. return self.load(f)
  390. def load(self, f):
  391. """Load a BSDF-encoded object from the given file object."""
  392. # Check magic string
  393. f4 = f.read(4)
  394. if f4 != b"BSDF":
  395. raise RuntimeError("This does not look like a BSDF file: %r" % f4)
  396. # Check version
  397. major_version = strunpack("<B", f.read(1))[0]
  398. minor_version = strunpack("<B", f.read(1))[0]
  399. file_version = "%i.%i" % (major_version, minor_version)
  400. if major_version != VERSION[0]: # major version should be 2
  401. t = (
  402. "Reading file with different major version (%s) "
  403. "from the implementation (%s)."
  404. )
  405. raise RuntimeError(t % (__version__, file_version))
  406. if minor_version > VERSION[1]: # minor should be < ours
  407. t = (
  408. "BSDF warning: reading file with higher minor version (%s) "
  409. "than the implementation (%s)."
  410. )
  411. logger.warning(t % (__version__, file_version))
  412. return self._decode(f)
  413. # %% Streaming and blob-files
  414. class BaseStream(object):
  415. """Base class for streams."""
  416. def __init__(self, mode="w"):
  417. self._i = 0
  418. self._count = -1
  419. if isinstance(mode, int):
  420. self._count = mode
  421. mode = "r"
  422. elif mode == "w":
  423. self._count = 0
  424. assert mode in ("r", "w")
  425. self._mode = mode
  426. self._f = None
  427. self._start_pos = 0
  428. def _activate(self, file, encode_func, decode_func):
  429. if self._f is not None: # Associated with another write
  430. raise IOError("Stream object cannot be activated twice?")
  431. self._f = file
  432. self._start_pos = self._f.tell()
  433. self._encode = encode_func
  434. self._decode = decode_func
  435. @property
  436. def mode(self):
  437. """The mode of this stream: 'r' or 'w'."""
  438. return self._mode
  439. class ListStream(BaseStream):
  440. """A streamable list object used for writing or reading.
  441. In read mode, it can also be iterated over.
  442. """
  443. @property
  444. def count(self):
  445. """The number of elements in the stream (can be -1 for unclosed
  446. streams in read-mode).
  447. """
  448. return self._count
  449. @property
  450. def index(self):
  451. """The current index of the element to read/write."""
  452. return self._i
  453. def append(self, item):
  454. """Append an item to the streaming list. The object is immediately
  455. serialized and written to the underlying file.
  456. """
  457. # if self._mode != 'w':
  458. # raise IOError('This ListStream is not in write mode.')
  459. if self._count != self._i:
  460. raise IOError("Can only append items to the end of the stream.")
  461. if self._f is None:
  462. raise IOError("List stream is not associated with a file yet.")
  463. if self._f.closed:
  464. raise IOError("Cannot stream to a close file.")
  465. self._encode(self._f, item, [self], None)
  466. self._i += 1
  467. self._count += 1
  468. def close(self, unstream=False):
  469. """Close the stream, marking the number of written elements. New
  470. elements may still be appended, but they won't be read during decoding.
  471. If ``unstream`` is False, the stream is turned into a regular list
  472. (not streaming).
  473. """
  474. # if self._mode != 'w':
  475. # raise IOError('This ListStream is not in write mode.')
  476. if self._count != self._i:
  477. raise IOError("Can only close when at the end of the stream.")
  478. if self._f is None:
  479. raise IOError("ListStream is not associated with a file yet.")
  480. if self._f.closed:
  481. raise IOError("Cannot close a stream on a close file.")
  482. i = self._f.tell()
  483. self._f.seek(self._start_pos - 8 - 1)
  484. self._f.write(spack("<B", 253 if unstream else 254))
  485. self._f.write(spack("<Q", self._count))
  486. self._f.seek(i)
  487. def next(self):
  488. """Read and return the next element in the streaming list.
  489. Raises StopIteration if the stream is exhausted.
  490. """
  491. if self._mode != "r":
  492. raise IOError("This ListStream in not in read mode.")
  493. if self._f is None:
  494. raise IOError("ListStream is not associated with a file yet.")
  495. if getattr(self._f, "closed", None): # not present on 2.7 http req :/
  496. raise IOError("Cannot read a stream from a close file.")
  497. if self._count >= 0:
  498. if self._i >= self._count:
  499. raise StopIteration()
  500. self._i += 1
  501. return self._decode(self._f)
  502. else:
  503. # This raises EOFError at some point.
  504. try:
  505. res = self._decode(self._f)
  506. self._i += 1
  507. return res
  508. except EOFError:
  509. self._count = self._i
  510. raise StopIteration()
  511. def __iter__(self):
  512. if self._mode != "r":
  513. raise IOError("Cannot iterate: ListStream in not in read mode.")
  514. return self
  515. def __next__(self):
  516. return self.next()
  517. class Blob(object):
  518. """Object to represent a blob of bytes. When used to write a BSDF file,
  519. it's a wrapper for bytes plus properties such as what compression to apply.
  520. When used to read a BSDF file, it can be used to read the data lazily, and
  521. also modify the data if reading in 'r+' mode and the blob isn't compressed.
  522. """
  523. # For now, this does not allow re-sizing blobs (within the allocated size)
  524. # but this can be added later.
  525. def __init__(self, bb, compression=0, extra_size=0, use_checksum=False):
  526. if isinstance(bb, bytes):
  527. self._f = None
  528. self.compressed = self._from_bytes(bb, compression)
  529. self.compression = compression
  530. self.allocated_size = self.used_size + extra_size
  531. self.use_checksum = use_checksum
  532. elif isinstance(bb, tuple) and len(bb) == 2 and hasattr(bb[0], "read"):
  533. self._f, allow_seek = bb
  534. self.compressed = None
  535. self._from_file(self._f, allow_seek)
  536. self._modified = False
  537. else:
  538. raise TypeError("Wrong argument to create Blob.")
  539. def _from_bytes(self, value, compression):
  540. """When used to wrap bytes in a blob."""
  541. if compression == 0:
  542. compressed = value
  543. elif compression == 1:
  544. compressed = zlib.compress(value, 9)
  545. elif compression == 2:
  546. compressed = bz2.compress(value, 9)
  547. else: # pragma: no cover
  548. assert False, "Unknown compression identifier"
  549. self.data_size = len(value)
  550. self.used_size = len(compressed)
  551. return compressed
  552. def _to_file(self, f):
  553. """Private friend method called by encoder to write a blob to a file."""
  554. # Write sizes - write at least in a size that allows resizing
  555. if self.allocated_size <= 250 and self.compression == 0:
  556. f.write(spack("<B", self.allocated_size))
  557. f.write(spack("<B", self.used_size))
  558. f.write(lencode(self.data_size))
  559. else:
  560. f.write(spack("<BQ", 253, self.allocated_size))
  561. f.write(spack("<BQ", 253, self.used_size))
  562. f.write(spack("<BQ", 253, self.data_size))
  563. # Compression and checksum
  564. f.write(spack("B", self.compression))
  565. if self.use_checksum:
  566. f.write(b"\xff" + hashlib.md5(self.compressed).digest())
  567. else:
  568. f.write(b"\x00")
  569. # Byte alignment (only necessary for uncompressed data)
  570. if self.compression == 0:
  571. alignment = 8 - (f.tell() + 1) % 8 # +1 for the byte to write
  572. f.write(spack("<B", alignment)) # padding for byte alignment
  573. f.write(b"\x00" * alignment)
  574. else:
  575. f.write(spack("<B", 0))
  576. # The actual data and extra space
  577. f.write(self.compressed)
  578. f.write(b"\x00" * (self.allocated_size - self.used_size))
  579. def _from_file(self, f, allow_seek):
  580. """Used when a blob is read by the decoder."""
  581. # Read blob header data (5 to 42 bytes)
  582. # Size
  583. allocated_size = strunpack("<B", f.read(1))[0]
  584. if allocated_size == 253:
  585. allocated_size = strunpack("<Q", f.read(8))[0] # noqa
  586. used_size = strunpack("<B", f.read(1))[0]
  587. if used_size == 253:
  588. used_size = strunpack("<Q", f.read(8))[0] # noqa
  589. data_size = strunpack("<B", f.read(1))[0]
  590. if data_size == 253:
  591. data_size = strunpack("<Q", f.read(8))[0] # noqa
  592. # Compression and checksum
  593. compression = strunpack("<B", f.read(1))[0]
  594. has_checksum = strunpack("<B", f.read(1))[0]
  595. if has_checksum:
  596. checksum = f.read(16)
  597. # Skip alignment
  598. alignment = strunpack("<B", f.read(1))[0]
  599. f.read(alignment)
  600. # Get or skip data + extra space
  601. if allow_seek:
  602. self.start_pos = f.tell()
  603. self.end_pos = self.start_pos + used_size
  604. f.seek(self.start_pos + allocated_size)
  605. else:
  606. self.start_pos = None
  607. self.end_pos = None
  608. self.compressed = f.read(used_size)
  609. f.read(allocated_size - used_size)
  610. # Store info
  611. self.alignment = alignment
  612. self.compression = compression
  613. self.use_checksum = checksum if has_checksum else None
  614. self.used_size = used_size
  615. self.allocated_size = allocated_size
  616. self.data_size = data_size
  617. def seek(self, p):
  618. """Seek to the given position (relative to the blob start)."""
  619. if self._f is None:
  620. raise RuntimeError(
  621. "Cannot seek in a blob " "that is not created by the BSDF decoder."
  622. )
  623. if p < 0:
  624. p = self.allocated_size + p
  625. if p < 0 or p > self.allocated_size:
  626. raise IOError("Seek beyond blob boundaries.")
  627. self._f.seek(self.start_pos + p)
  628. def tell(self):
  629. """Get the current file pointer position (relative to the blob start)."""
  630. if self._f is None:
  631. raise RuntimeError(
  632. "Cannot tell in a blob " "that is not created by the BSDF decoder."
  633. )
  634. return self._f.tell() - self.start_pos
  635. def write(self, bb):
  636. """Write bytes to the blob."""
  637. if self._f is None:
  638. raise RuntimeError(
  639. "Cannot write in a blob " "that is not created by the BSDF decoder."
  640. )
  641. if self.compression:
  642. raise IOError("Cannot arbitrarily write in compressed blob.")
  643. if self._f.tell() + len(bb) > self.end_pos:
  644. raise IOError("Write beyond blob boundaries.")
  645. self._modified = True
  646. return self._f.write(bb)
  647. def read(self, n):
  648. """Read n bytes from the blob."""
  649. if self._f is None:
  650. raise RuntimeError(
  651. "Cannot read in a blob " "that is not created by the BSDF decoder."
  652. )
  653. if self.compression:
  654. raise IOError("Cannot arbitrarily read in compressed blob.")
  655. if self._f.tell() + n > self.end_pos:
  656. raise IOError("Read beyond blob boundaries.")
  657. return self._f.read(n)
  658. def get_bytes(self):
  659. """Get the contents of the blob as bytes."""
  660. if self.compressed is not None:
  661. compressed = self.compressed
  662. else:
  663. i = self._f.tell()
  664. self.seek(0)
  665. compressed = self._f.read(self.used_size)
  666. self._f.seek(i)
  667. if self.compression == 0:
  668. value = compressed
  669. elif self.compression == 1:
  670. value = zlib.decompress(compressed)
  671. elif self.compression == 2:
  672. value = bz2.decompress(compressed)
  673. else: # pragma: no cover
  674. raise RuntimeError("Invalid compression %i" % self.compression)
  675. return value
  676. def update_checksum(self):
  677. """Reset the blob's checksum if present. Call this after modifying
  678. the data.
  679. """
  680. # or ... should the presence of a checksum mean that data is proteced?
  681. if self.use_checksum and self._modified:
  682. self.seek(0)
  683. compressed = self._f.read(self.used_size)
  684. self._f.seek(self.start_pos - self.alignment - 1 - 16)
  685. self._f.write(hashlib.md5(compressed).digest())
  686. # %% High-level functions
  687. def encode(ob, extensions=None, **options):
  688. """Save (BSDF-encode) the given object to bytes.
  689. See `BSDFSerializer` for details on extensions and options.
  690. """
  691. s = BsdfSerializer(extensions, **options)
  692. return s.encode(ob)
  693. def save(f, ob, extensions=None, **options):
  694. """Save (BSDF-encode) the given object to the given filename or
  695. file object. See` BSDFSerializer` for details on extensions and options.
  696. """
  697. s = BsdfSerializer(extensions, **options)
  698. if isinstance(f, string_types):
  699. with open(f, "wb") as fp:
  700. return s.save(fp, ob)
  701. else:
  702. return s.save(f, ob)
  703. def decode(bb, extensions=None, **options):
  704. """Load a (BSDF-encoded) structure from bytes.
  705. See `BSDFSerializer` for details on extensions and options.
  706. """
  707. s = BsdfSerializer(extensions, **options)
  708. return s.decode(bb)
  709. def load(f, extensions=None, **options):
  710. """Load a (BSDF-encoded) structure from the given filename or file object.
  711. See `BSDFSerializer` for details on extensions and options.
  712. """
  713. s = BsdfSerializer(extensions, **options)
  714. if isinstance(f, string_types):
  715. if f.startswith(("~/", "~\\")): # pragma: no cover
  716. f = os.path.expanduser(f)
  717. with open(f, "rb") as fp:
  718. return s.load(fp)
  719. else:
  720. return s.load(f)
  721. # Aliases for json compat
  722. loads = decode
  723. dumps = encode
  724. # %% Standard extensions
  725. # Defining extensions as a dict would be more compact and feel lighter, but
  726. # that would only allow lambdas, which is too limiting, e.g. for ndarray
  727. # extension.
  728. class Extension(object):
  729. """Base class to implement BSDF extensions for special data types.
  730. Extension classes are provided to the BSDF serializer, which
  731. instantiates the class. That way, the extension can be somewhat dynamic:
  732. e.g. the NDArrayExtension exposes the ndarray class only when numpy
  733. is imported.
  734. A extension instance must have two attributes. These can be attributes of
  735. the class, or of the instance set in ``__init__()``:
  736. * name (str): the name by which encoded values will be identified.
  737. * cls (type): the type (or list of types) to match values with.
  738. This is optional, but it makes the encoder select extensions faster.
  739. Further, it needs 3 methods:
  740. * `match(serializer, value) -> bool`: return whether the extension can
  741. convert the given value. The default is ``isinstance(value, self.cls)``.
  742. * `encode(serializer, value) -> encoded_value`: the function to encode a
  743. value to more basic data types.
  744. * `decode(serializer, encoded_value) -> value`: the function to decode an
  745. encoded value back to its intended representation.
  746. """
  747. name = ""
  748. cls = ()
  749. def __repr__(self):
  750. return "<BSDF extension %r at 0x%s>" % (self.name, hex(id(self)))
  751. def match(self, s, v):
  752. return isinstance(v, self.cls)
  753. def encode(self, s, v):
  754. raise NotImplementedError()
  755. def decode(self, s, v):
  756. raise NotImplementedError()
  757. class ComplexExtension(Extension):
  758. name = "c"
  759. cls = complex
  760. def encode(self, s, v):
  761. return (v.real, v.imag)
  762. def decode(self, s, v):
  763. return complex(v[0], v[1])
  764. class NDArrayExtension(Extension):
  765. name = "ndarray"
  766. def __init__(self):
  767. if "numpy" in sys.modules:
  768. import numpy as np
  769. self.cls = np.ndarray
  770. def match(self, s, v): # pragma: no cover - e.g. work for nd arrays in JS
  771. return hasattr(v, "shape") and hasattr(v, "dtype") and hasattr(v, "tobytes")
  772. def encode(self, s, v):
  773. return dict(shape=v.shape, dtype=text_type(v.dtype), data=v.tobytes())
  774. def decode(self, s, v):
  775. try:
  776. import numpy as np
  777. except ImportError: # pragma: no cover
  778. return v
  779. a = np.frombuffer(v["data"], dtype=v["dtype"])
  780. a.shape = v["shape"]
  781. return a
  782. standard_extensions = [ComplexExtension, NDArrayExtension]
  783. if __name__ == "__main__":
  784. # Invoke CLI
  785. import bsdf_cli
  786. bsdf_cli.main()