| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915 |
- #!/usr/bin/env python
- # This file is distributed under the terms of the 2-clause BSD License.
- # Copyright (c) 2017-2018, Almar Klein
- """
- Python implementation of the Binary Structured Data Format (BSDF).
- BSDF is a binary format for serializing structured (scientific) data.
- See http://bsdf.io for more information.
- This is the reference implementation, which is relatively
- sophisticated, providing e.g. lazy loading of blobs and streamed
- reading/writing. A simpler Python implementation is available as
- ``bsdf_lite.py``.
- This module has no dependencies and works on Python 2.7 and 3.4+.
- Note: on Legacy Python (Python 2.7), non-Unicode strings are encoded as bytes.
- """
- # todo: in 2020, remove six stuff, __future__ and _isidentifier
- # todo: in 2020, remove 'utf-8' args to encode/decode; it's faster
- from __future__ import absolute_import, division, print_function
- import bz2
- import hashlib
- import logging
- import os
- import re
- import struct
- import sys
- import types
- import zlib
- from io import BytesIO
logger = logging.getLogger(__name__)

# Notes on versioning: the major and minor numbers correspond to the
# BSDF format version. The major number is increased when backward
# incompatible changes are introduced. An implementation must raise an
# exception when the file being read has a higher major version. The
# minor number is increased when new backward compatible features are
# introduced. An implementation must display a warning when the file
# being read has a higher minor version. The patch version is increased
# for subsequent releases of the implementation.
VERSION = (2, 1, 2)
__version__ = ".".join(map(str, VERSION))
- # %% The encoder and decoder implementation
- # From six.py
PY3 = sys.version_info[0] >= 3

if PY3:
    # On Python 3 all text flavours collapse onto ``str``.
    text_type = string_types = unicode_types = str
    integer_types = int
    classtypes = type
else:  # pragma: no cover
    logging.basicConfig()  # avoid "no handlers found" error
    text_type = unicode  # noqa
    string_types = basestring  # noqa
    unicode_types = unicode  # noqa
    integer_types = (int, long)  # noqa
    classtypes = type, types.ClassType

# Shorthands for the struct functions used throughout this module
spack = struct.pack
strunpack = struct.unpack
def lencode(x):
    """Encode an unsigned integer into a variable sized blob of bytes.

    Values up to and including 250 take a single byte. Anything larger
    is written as the marker byte 253 followed by the value as a
    little-endian unsigned 64 bit integer (9 bytes in total).
    """
    # 16 bit (marker 251) and 32 bit (marker 252) forms could be added,
    # but the gain is low: 9 bytes for collections with over 250
    # elements is marginal anyway.
    if x > 250:
        return struct.pack("<BQ", 253, x)
    return struct.pack("<B", x)
- # Include len decoder for completeness; we've inlined it for performance.
def lendecode(f):
    """Decode an unsigned integer from a file.

    Reads one byte from ``f``; if that byte is the marker 253, the
    actual value follows as a little-endian unsigned 64 bit integer.
    """
    (first,) = struct.unpack("<B", f.read(1))
    if first != 253:
        return first
    (value,) = struct.unpack("<Q", f.read(8))  # noqa
    return value
def encode_type_id(b, ext_id):
    """Encode the type identifier, with or without extension id.

    Without an extension id the single-byte identifier ``b`` is returned
    unchanged. With one, the identifier is upper-cased and followed by
    the length-prefixed UTF-8 encoded extension name.
    """
    if ext_id is None:
        return b  # noqa
    name_bytes = ext_id.encode("UTF-8")
    return b"".join((b.upper(), lencode(len(name_bytes)), name_bytes))  # noqa
def _isidentifier(s):  # pragma: no cover
    """Use of str.isidentifier() for Legacy Python, but slower.

    Returns True when ``s`` is a string consisting solely of word
    characters that does not start with a digit, False otherwise.
    """
    # http://stackoverflow.com/questions/2544972/
    if not isinstance(s, string_types):
        return False
    # Use \Z rather than $: the latter also matches just before a
    # trailing newline, which would let "name\n" pass as an identifier.
    if re.match(r"^\w+\Z", s, re.UNICODE) is None:
        return False
    return re.match(r"^[0-9]", s) is None
class BsdfSerializer(object):
    """Instances of this class represent a BSDF encoder/decoder.

    It acts as a placeholder for a set of extensions and encoding/decoding
    options. Use this to predefine extensions and options for high
    performance encoding/decoding. For general use, see the functions
    `save()`, `encode()`, `load()`, and `decode()`.

    This implementation of BSDF supports streaming lists (keep adding
    to a list after writing the main file), lazy loading of blobs, and
    in-place editing of blobs (for streams opened with a+).

    Options for encoding:

    * compression (int or str): ``0`` or "no" for no compression (default),
      ``1`` or "zlib" for Zlib compression (same as zip files and PNG), and
      ``2`` or "bz2" for Bz2 compression (more compact but slower writing).
      Note that some BSDF implementations (e.g. JavaScript) may not support
      compression.
    * use_checksum (bool): whether to include a checksum with binary blobs.
    * float64 (bool): Whether to write floats as 64 bit (default) or 32 bit.

    Options for decoding:

    * load_streaming (bool): if True, and the final object in the structure was
      a stream, will make it available as a stream in the decoded object.
    * lazy_blob (bool): if True, bytes are represented as Blob objects that can
      be used to lazily access the data, and also overwrite the data if the
      file is open in a+ mode.
    """

    def __init__(self, extensions=None, **options):
        self._extensions = {}  # name -> extension
        self._extensions_by_cls = {}  # cls -> (name, extension.encode)
        if extensions is None:
            extensions = standard_extensions
        for extension in extensions:
            self.add_extension(extension)
        self._parse_options(**options)

    def _parse_options(
        self,
        compression=0,
        use_checksum=False,
        float64=True,
        load_streaming=False,
        lazy_blob=False,
    ):
        """Validate the options and store them on the instance.

        Raises TypeError when ``compression`` is not one of the accepted
        integers or names.
        """
        # Validate compression; names are mapped onto their integer ids
        if isinstance(compression, string_types):
            m = {"no": 0, "zlib": 1, "bz2": 2}
            compression = m.get(compression.lower(), compression)
        if compression not in (0, 1, 2):
            raise TypeError('Compression must be 0, 1, 2, "no", "zlib", or "bz2"')
        self._compression = compression

        # Other encoding args
        self._use_checksum = bool(use_checksum)
        self._float64 = bool(float64)

        # Decoding args
        self._load_streaming = bool(load_streaming)
        self._lazy_blob = bool(lazy_blob)

    def add_extension(self, extension_class):
        """Add an extension to this serializer instance, which must be
        a subclass of Extension. Can be used as a decorator.

        Raises TypeError for a non-Extension class, a non-str name or
        non-type entries in ``cls``; raises NameError for an empty or
        too long name. Returns ``extension_class`` unchanged.
        """
        # Check class
        if not (
            isinstance(extension_class, type) and issubclass(extension_class, Extension)
        ):
            raise TypeError("add_extension() expects a Extension class.")
        extension = extension_class()

        # Get name
        name = extension.name
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if len(name) == 0 or len(name) > 250:
            raise NameError("Extension names must be nonempty and shorter than 251 chars.")
        if name in self._extensions:
            logger.warning(
                'BSDF warning: overwriting extension "%s", '
                "consider removing first" % name
            )

        # Get classes; ``cls`` may be empty, a single type or a sequence
        cls = extension.cls
        if not cls:
            clss = []
        elif isinstance(cls, (tuple, list)):
            clss = cls
        else:
            clss = [cls]
        for cls in clss:
            if not isinstance(cls, classtypes):
                raise TypeError("Extension classes must be types.")

        # Store
        for cls in clss:
            self._extensions_by_cls[cls] = name, extension.encode
        self._extensions[name] = extension
        return extension_class

    def remove_extension(self, name):
        """Remove an extension by its unique name.

        Unknown names are silently ignored. Raises TypeError when
        ``name`` is not a str.
        """
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if name in self._extensions:
            self._extensions.pop(name)
        # Also drop every class -> extension mapping that refers to it
        for cls in list(self._extensions_by_cls.keys()):
            if self._extensions_by_cls[cls][0] == name:
                self._extensions_by_cls.pop(cls)

    def _encode(self, f, value, streams, ext_id):
        """Main encoder function.

        Writes ``value`` to file object ``f``. ``streams`` is the list
        that collects the (at most one) stream object of the file.
        ``ext_id`` is the name of the extension that produced ``value``,
        or None for plain values.
        """
        x = encode_type_id

        if value is None:
            f.write(x(b"v", ext_id))  # V for void
        elif value is True:
            f.write(x(b"y", ext_id))  # Y for yes
        elif value is False:
            f.write(x(b"n", ext_id))  # N for no
        elif isinstance(value, integer_types):
            if -32768 <= value <= 32767:
                # Explicitly little-endian, matching the "<h" that the
                # decoder uses; a bare "h" would follow the host byte order.
                f.write(x(b"h", ext_id) + spack("<h", value))  # H for ...
            else:
                f.write(x(b"i", ext_id) + spack("<q", value))  # I for int
        elif isinstance(value, float):
            if self._float64:
                f.write(x(b"d", ext_id) + spack("<d", value))  # D for double
            else:
                f.write(x(b"f", ext_id) + spack("<f", value))  # f for float
        elif isinstance(value, unicode_types):
            bb = value.encode("UTF-8")
            f.write(x(b"s", ext_id) + lencode(len(bb)))  # S for str
            f.write(bb)
        elif isinstance(value, (list, tuple)):
            f.write(x(b"l", ext_id) + lencode(len(value)))  # L for list
            for v in value:
                self._encode(f, v, streams, None)
        elif isinstance(value, dict):
            f.write(x(b"m", ext_id) + lencode(len(value)))  # M for mapping
            for key, v in value.items():
                if PY3:
                    assert key.isidentifier()  # faster
                else:  # pragma: no cover
                    assert _isidentifier(key)
                # yield ' ' * indent + key
                name_b = key.encode("UTF-8")
                f.write(lencode(len(name_b)))
                f.write(name_b)
                self._encode(f, v, streams, None)
        elif isinstance(value, bytes):
            f.write(x(b"b", ext_id))  # B for blob
            blob = Blob(
                value, compression=self._compression, use_checksum=self._use_checksum
            )
            blob._to_file(f)  # noqa
        elif isinstance(value, Blob):
            f.write(x(b"b", ext_id))  # B for blob
            value._to_file(f)  # noqa
        elif isinstance(value, BaseStream):
            # Initialize the stream
            if value.mode != "w":
                raise ValueError("Cannot serialize a read-mode stream.")
            elif isinstance(value, ListStream):
                # 255 marks an unclosed stream; the count is patched on close
                f.write(x(b"l", ext_id) + spack("<BQ", 255, 0))  # L for list
            else:
                raise TypeError("Only ListStream is supported")
            # Mark this as *the* stream, and activate the stream.
            # The save() function verifies this is the last written object.
            if len(streams) > 0:
                raise ValueError("Can only have one stream per file.")
            streams.append(value)
            value._activate(f, self._encode, self._decode)  # noqa
        else:
            if ext_id is not None:
                raise ValueError(
                    "Extension %s wronfully encodes object to another "
                    "extension object (though it may encode to a list/dict "
                    "that contains other extension objects)." % ext_id
                )
            # Try if the value is of a type we know
            ex = self._extensions_by_cls.get(value.__class__, None)
            # Maybe its a subclass of a type we know
            if ex is None:
                for name, c in self._extensions.items():
                    if c.match(self, value):
                        ex = name, c.encode
                        break
                else:
                    ex = None
            # Success or fail
            if ex is not None:
                ext_id2, extension_encode = ex
                self._encode(f, extension_encode(self, value), streams, ext_id2)
            else:
                t = (
                    "Class %r is not a valid base BSDF type, nor is it "
                    "handled by an extension."
                )
                raise TypeError(t % value.__class__.__name__)

    def _decode(self, f):
        """Main decoder function.

        Reads one value from file object ``f`` and returns it. Raises
        EOFError when ``f`` is exhausted before a type identifier could
        be read, and RuntimeError on an unknown type identifier.
        """
        # Get value
        char = f.read(1)
        c = char.lower()

        # Conversion (uppercase value identifiers signify converted values)
        if not char:
            raise EOFError()
        elif char != c:
            n = strunpack("<B", f.read(1))[0]
            # The encoder length-prefixes the UTF-8 *bytes* of the name,
            # which can exceed 250 for non-ASCII names of up to 250 chars,
            # so the long form must be understood here as well.
            if n == 253:
                n = strunpack("<Q", f.read(8))[0]  # noqa
            ext_id = f.read(n).decode("UTF-8")
        else:
            ext_id = None

        if c == b"v":
            value = None
        elif c == b"y":
            value = True
        elif c == b"n":
            value = False
        elif c == b"h":
            value = strunpack("<h", f.read(2))[0]
        elif c == b"i":
            value = strunpack("<q", f.read(8))[0]
        elif c == b"f":
            value = strunpack("<f", f.read(4))[0]
        elif c == b"d":
            value = strunpack("<d", f.read(8))[0]
        elif c == b"s":
            n_s = strunpack("<B", f.read(1))[0]
            if n_s == 253:
                n_s = strunpack("<Q", f.read(8))[0]  # noqa
            value = f.read(n_s).decode("UTF-8")
        elif c == b"l":
            n = strunpack("<B", f.read(1))[0]
            if n >= 254:
                # Streaming: 254 is a closed stream, 255 an unclosed one
                closed = n == 254
                n = strunpack("<Q", f.read(8))[0]
                if self._load_streaming:
                    value = ListStream(n if closed else "r")
                    value._activate(f, self._encode, self._decode)  # noqa
                elif closed:
                    value = [self._decode(f) for i in range(n)]
                else:
                    # Unclosed stream: read items until the file ends
                    value = []
                    try:
                        while True:
                            value.append(self._decode(f))
                    except EOFError:
                        pass
            else:
                # Normal
                if n == 253:
                    n = strunpack("<Q", f.read(8))[0]  # noqa
                value = [self._decode(f) for i in range(n)]
        elif c == b"m":
            value = dict()
            n = strunpack("<B", f.read(1))[0]
            if n == 253:
                n = strunpack("<Q", f.read(8))[0]  # noqa
            for i in range(n):
                n_name = strunpack("<B", f.read(1))[0]
                if n_name == 253:
                    n_name = strunpack("<Q", f.read(8))[0]  # noqa
                assert n_name > 0
                name = f.read(n_name).decode("UTF-8")
                value[name] = self._decode(f)
        elif c == b"b":
            if self._lazy_blob:
                value = Blob((f, True))
            else:
                blob = Blob((f, False))
                value = blob.get_bytes()
        else:
            raise RuntimeError("Parse error %r" % char)

        # Convert value if we have an extension for it
        if ext_id is not None:
            extension = self._extensions.get(ext_id, None)
            if extension is not None:
                value = extension.decode(self, value)
            else:
                logger.warning("BSDF warning: no extension found for %r" % ext_id)

        return value

    def encode(self, ob):
        """Save the given object to bytes."""
        f = BytesIO()
        self.save(f, ob)
        return f.getvalue()

    def save(self, f, ob):
        """Write the given object to the given file object.

        Raises ValueError when ``ob`` contains a stream that is not the
        last object to be encoded.
        """
        f.write(b"BSDF")
        f.write(struct.pack("<B", VERSION[0]))
        f.write(struct.pack("<B", VERSION[1]))

        # Prepare streaming, this list will have 0 or 1 item at the end
        streams = []

        self._encode(f, ob, streams, None)

        # Verify that stream object was at the end, and add initial elements
        if len(streams) > 0:
            stream = streams[0]
            if stream._start_pos != f.tell():
                raise ValueError("The stream object must be the last object to be encoded.")

    def decode(self, bb):
        """Load the data structure that is BSDF-encoded in the given bytes."""
        f = BytesIO(bb)
        return self.load(f)

    def load(self, f):
        """Load a BSDF-encoded object from the given file object.

        Raises RuntimeError when the magic string is missing or when the
        major version of the file differs from this implementation. A
        higher minor version only logs a warning.
        """
        # Check magic string
        f4 = f.read(4)
        if f4 != b"BSDF":
            raise RuntimeError("This does not look like a BSDF file: %r" % f4)
        # Check version
        major_version = strunpack("<B", f.read(1))[0]
        minor_version = strunpack("<B", f.read(1))[0]
        file_version = "%i.%i" % (major_version, minor_version)
        if major_version != VERSION[0]:  # major version should be 2
            t = (
                "Reading file with different major version (%s) "
                "from the implementation (%s)."
            )
            # First placeholder is the file's version, second is ours
            raise RuntimeError(t % (file_version, __version__))
        if minor_version > VERSION[1]:  # minor should be < ours
            t = (
                "BSDF warning: reading file with higher minor version (%s) "
                "than the implementation (%s)."
            )
            logger.warning(t % (file_version, __version__))

        return self._decode(f)
- # %% Streaming and blob-files
class BaseStream(object):
    """Base class for streams.

    ``mode`` is either "w" (the default), "r", or an integer. An integer
    is taken to be the known number of elements and implies mode "r".
    """

    def __init__(self, mode="w"):
        if isinstance(mode, int):
            # A known element count was given
            count, mode = mode, "r"
        elif mode == "w":
            count = 0
        else:
            count = -1  # not known (yet)
        assert mode in ("r", "w")
        self._i = 0
        self._count = count
        self._mode = mode
        self._f = None
        self._start_pos = 0

    def _activate(self, file, encode_func, decode_func):
        """Bind this stream to a file and the serializer's codec functions."""
        if self._f is not None:  # Associated with another write
            raise IOError("Stream object cannot be activated twice?")
        self._f = file
        self._start_pos = file.tell()
        self._encode = encode_func
        self._decode = decode_func

    @property
    def mode(self):
        """The mode of this stream: 'r' or 'w'."""
        return self._mode
class ListStream(BaseStream):
    """A streamable list object used for writing or reading.
    In read mode, it can also be iterated over.
    """

    @property
    def count(self):
        """The number of elements in the stream (can be -1 for unclosed
        streams in read-mode).
        """
        return self._count

    @property
    def index(self):
        """The current index of the element to read/write."""
        return self._i

    def append(self, item):
        """Append an item to the streaming list. The object is immediately
        serialized and written to the underlying file.

        Raises IOError when the stream is not positioned at its end, is
        not associated with a file yet, or when that file is closed.
        """
        # if self._mode != 'w':
        #    raise IOError('This ListStream is not in write mode.')
        if self._count != self._i:
            raise IOError("Can only append items to the end of the stream.")
        if self._f is None:
            raise IOError("List stream is not associated with a file yet.")
        # getattr, as in next(): not every file-like object has .closed
        if getattr(self._f, "closed", None):
            raise IOError("Cannot stream to a close file.")
        self._encode(self._f, item, [self], None)
        self._i += 1
        self._count += 1

    def close(self, unstream=False):
        """Close the stream, marking the number of written elements. New
        elements may still be appended, but they won't be read during decoding.
        If ``unstream`` is True, the stream is turned into a regular list
        (not streaming).

        Raises IOError when the stream is not positioned at its end, is
        not associated with a file yet, or when that file is closed.
        """
        # if self._mode != 'w':
        #    raise IOError('This ListStream is not in write mode.')
        if self._count != self._i:
            raise IOError("Can only close when at the end of the stream.")
        if self._f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if getattr(self._f, "closed", None):
            raise IOError("Cannot close a stream on a close file.")
        i = self._f.tell()
        # Step back over the 8 byte count and the 1 byte marker that
        # directly precede the first element, and overwrite both.
        self._f.seek(self._start_pos - 8 - 1)
        self._f.write(spack("<B", 253 if unstream else 254))
        self._f.write(spack("<Q", self._count))
        self._f.seek(i)

    def next(self):
        """Read and return the next element in the streaming list.
        Raises StopIteration if the stream is exhausted.
        """
        if self._mode != "r":
            raise IOError("This ListStream in not in read mode.")
        if self._f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if getattr(self._f, "closed", None):  # not present on 2.7 http req :/
            raise IOError("Cannot read a stream from a close file.")
        if self._count >= 0:
            # Known number of elements
            if self._i >= self._count:
                raise StopIteration()
            self._i += 1
            return self._decode(self._f)
        else:
            # Unknown number of elements: read until the file runs out.
            # This raises EOFError at some point.
            try:
                res = self._decode(self._f)
                self._i += 1
                return res
            except EOFError:
                self._count = self._i
                raise StopIteration()

    def __iter__(self):
        if self._mode != "r":
            raise IOError("Cannot iterate: ListStream in not in read mode.")
        return self

    def __next__(self):
        return self.next()
class Blob(object):
    """Object to represent a blob of bytes. When used to write a BSDF file,
    it's a wrapper for bytes plus properties such as what compression to apply.
    When used to read a BSDF file, it can be used to read the data lazily, and
    also modify the data if reading in 'r+' mode and the blob isn't compressed.

    The first argument is either the bytes to wrap, or a 2-tuple
    ``(file, allow_seek)`` as passed in by the decoder. Anything else
    raises TypeError.
    """

    # For now, this does not allow re-sizing blobs (within the allocated size)
    # but this can be added later.

    def __init__(self, bb, compression=0, extra_size=0, use_checksum=False):
        # Set for every kind of blob, so that update_checksum() can be
        # called safely on blobs that wrap bytes as well.
        self._modified = False
        if isinstance(bb, bytes):
            self._f = None
            self.compressed = self._from_bytes(bb, compression)
            self.compression = compression
            self.allocated_size = self.used_size + extra_size
            self.use_checksum = use_checksum
        elif isinstance(bb, tuple) and len(bb) == 2 and hasattr(bb[0], "read"):
            self._f, allow_seek = bb
            self.compressed = None
            self._from_file(self._f, allow_seek)
        else:
            raise TypeError("Wrong argument to create Blob.")

    def _from_bytes(self, value, compression):
        """When used to wrap bytes in a blob.

        Sets ``data_size`` and ``used_size`` and returns the (possibly
        compressed) bytes.
        """
        if compression == 0:
            compressed = value
        elif compression == 1:
            compressed = zlib.compress(value, 9)
        elif compression == 2:
            compressed = bz2.compress(value, 9)
        else:  # pragma: no cover
            assert False, "Unknown compression identifier"

        self.data_size = len(value)
        self.used_size = len(compressed)
        return compressed

    def _to_file(self, f):
        """Private friend method called by encoder to write a blob to a file."""
        # Write sizes - write at least in a size that allows resizing
        if self.allocated_size <= 250 and self.compression == 0:
            f.write(spack("<B", self.allocated_size))
            f.write(spack("<B", self.used_size))
            f.write(lencode(self.data_size))
        else:
            f.write(spack("<BQ", 253, self.allocated_size))
            f.write(spack("<BQ", 253, self.used_size))
            f.write(spack("<BQ", 253, self.data_size))
        # Compression and checksum
        f.write(spack("B", self.compression))
        if self.use_checksum:
            f.write(b"\xff" + hashlib.md5(self.compressed).digest())
        else:
            f.write(b"\x00")
        # Byte alignment (only necessary for uncompressed data)
        if self.compression == 0:
            # Yields 1..8 padding bytes (8, not 0, when already aligned)
            alignment = 8 - (f.tell() + 1) % 8  # +1 for the byte to write
            f.write(spack("<B", alignment))  # padding for byte alignment
            f.write(b"\x00" * alignment)
        else:
            f.write(spack("<B", 0))
        # The actual data and extra space
        f.write(self.compressed)
        f.write(b"\x00" * (self.allocated_size - self.used_size))

    def _from_file(self, f, allow_seek):
        """Used when a blob is read by the decoder.

        With ``allow_seek`` the data is skipped and only its position is
        remembered; otherwise the data is read into memory right away.
        """
        # Read blob header data (5 to 42 bytes)
        # Size
        allocated_size = strunpack("<B", f.read(1))[0]
        if allocated_size == 253:
            allocated_size = strunpack("<Q", f.read(8))[0]  # noqa
        used_size = strunpack("<B", f.read(1))[0]
        if used_size == 253:
            used_size = strunpack("<Q", f.read(8))[0]  # noqa
        data_size = strunpack("<B", f.read(1))[0]
        if data_size == 253:
            data_size = strunpack("<Q", f.read(8))[0]  # noqa
        # Compression and checksum
        compression = strunpack("<B", f.read(1))[0]
        has_checksum = strunpack("<B", f.read(1))[0]
        if has_checksum:
            checksum = f.read(16)
        # Skip alignment
        alignment = strunpack("<B", f.read(1))[0]
        f.read(alignment)
        # Get or skip data + extra space
        if allow_seek:
            self.start_pos = f.tell()
            self.end_pos = self.start_pos + used_size
            f.seek(self.start_pos + allocated_size)
        else:
            self.start_pos = None
            self.end_pos = None
            self.compressed = f.read(used_size)
            f.read(allocated_size - used_size)
        # Store info
        self.alignment = alignment
        self.compression = compression
        # For blobs read from file this holds the checksum bytes (or None)
        self.use_checksum = checksum if has_checksum else None
        self.used_size = used_size
        self.allocated_size = allocated_size
        self.data_size = data_size

    def seek(self, p):
        """Seek to the given position (relative to the blob start).

        Negative positions count back from the end of the allocated
        space. Raises IOError when the position is out of bounds.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot seek in a blob that is not created by the BSDF decoder."
            )
        if p < 0:
            p = self.allocated_size + p
        if p < 0 or p > self.allocated_size:
            raise IOError("Seek beyond blob boundaries.")
        self._f.seek(self.start_pos + p)

    def tell(self):
        """Get the current file pointer position (relative to the blob start)."""
        if self._f is None:
            raise RuntimeError(
                "Cannot tell in a blob that is not created by the BSDF decoder."
            )
        return self._f.tell() - self.start_pos

    def write(self, bb):
        """Write bytes to the blob.

        Raises IOError for compressed blobs and for writes that would
        run past the end of the blob data.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot write in a blob that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily write in compressed blob.")
        if self._f.tell() + len(bb) > self.end_pos:
            raise IOError("Write beyond blob boundaries.")
        self._modified = True
        return self._f.write(bb)

    def read(self, n):
        """Read n bytes from the blob.

        Raises IOError for compressed blobs and for reads that would
        run past the end of the blob data.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot read in a blob that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily read in compressed blob.")
        if self._f.tell() + n > self.end_pos:
            raise IOError("Read beyond blob boundaries.")
        return self._f.read(n)

    def get_bytes(self):
        """Get the contents of the blob as bytes."""
        if self.compressed is not None:
            compressed = self.compressed
        else:
            # Lazy blob: fetch the data and put the file pointer back
            i = self._f.tell()
            self.seek(0)
            compressed = self._f.read(self.used_size)
            self._f.seek(i)
        if self.compression == 0:
            value = compressed
        elif self.compression == 1:
            value = zlib.decompress(compressed)
        elif self.compression == 2:
            value = bz2.decompress(compressed)
        else:  # pragma: no cover
            raise RuntimeError("Invalid compression %i" % self.compression)
        return value

    def update_checksum(self):
        """Reset the blob's checksum if present. Call this after modifying
        the data.

        Does nothing when the blob has no checksum or was not written to.
        """
        # or ... should the presence of a checksum mean that data is proteced?
        if self.use_checksum and self._modified:
            self.seek(0)
            compressed = self._f.read(self.used_size)
            # The 16 checksum bytes sit right before the alignment byte
            # and the padding that precede the data.
            self._f.seek(self.start_pos - self.alignment - 1 - 16)
            self._f.write(hashlib.md5(compressed).digest())
- # %% High-level functions
def encode(ob, extensions=None, **options):
    """Save (BSDF-encode) the given object to bytes.
    See `BsdfSerializer` for details on extensions and options.
    """
    return BsdfSerializer(extensions, **options).encode(ob)
def save(f, ob, extensions=None, **options):
    """Save (BSDF-encode) the given object to the given filename or
    file object. See `BsdfSerializer` for details on extensions and options.

    A filename that starts with ``~/`` (or ``~\\``) is expanded to the
    user's home directory, the same way `load()` does.
    """
    s = BsdfSerializer(extensions, **options)
    if isinstance(f, string_types):
        if f.startswith(("~/", "~\\")):  # pragma: no cover
            f = os.path.expanduser(f)
        with open(f, "wb") as fp:
            return s.save(fp, ob)
    else:
        return s.save(f, ob)
def decode(bb, extensions=None, **options):
    """Load a (BSDF-encoded) structure from bytes.
    See `BsdfSerializer` for details on extensions and options.
    """
    return BsdfSerializer(extensions, **options).decode(bb)
def load(f, extensions=None, **options):
    """Load a (BSDF-encoded) structure from the given filename or file object.
    See `BsdfSerializer` for details on extensions and options.
    """
    serializer = BsdfSerializer(extensions, **options)
    if not isinstance(f, string_types):
        # Not a filename, so treat it as a file object
        return serializer.load(f)
    if f.startswith(("~/", "~\\")):  # pragma: no cover
        f = os.path.expanduser(f)
    with open(f, "rb") as fp:
        return serializer.load(fp)


# Aliases for json compat
loads = decode
dumps = encode
- # %% Standard extensions
- # Defining extensions as a dict would be more compact and feel lighter, but
- # that would only allow lambdas, which is too limiting, e.g. for ndarray
- # extension.
class Extension(object):
    """Base class to implement BSDF extensions for special data types.

    Extension classes are provided to the BSDF serializer, which
    instantiates the class. That way, the extension can be somewhat dynamic:
    e.g. the NDArrayExtension exposes the ndarray class only when numpy
    is imported.

    A extension instance must have two attributes. These can be attributes of
    the class, or of the instance set in ``__init__()``:

    * name (str): the name by which encoded values will be identified.
    * cls (type): the type (or list of types) to match values with.
      This is optional, but it makes the encoder select extensions faster.

    Further, it needs 3 methods:

    * `match(serializer, value) -> bool`: return whether the extension can
      convert the given value. The default is ``isinstance(value, self.cls)``.
    * `encode(serializer, value) -> encoded_value`: the function to encode a
      value to more basic data types.
    * `decode(serializer, encoded_value) -> value`: the function to decode an
      encoded value back to its intended representation.
    """

    name = ""
    cls = ()

    def __repr__(self):
        # hex() already includes the "0x" prefix
        return "<BSDF extension %r at %s>" % (self.name, hex(id(self)))

    def match(self, s, v):
        """Return whether this extension can convert value ``v``."""
        return isinstance(v, self.cls)

    def encode(self, s, v):
        """Encode ``v`` to more basic data types; to be implemented by subclasses."""
        raise NotImplementedError()

    def decode(self, s, v):
        """Decode ``v`` back to its intended type; to be implemented by subclasses."""
        raise NotImplementedError()
class ComplexExtension(Extension):
    """Extension that stores complex numbers as a (real, imag) pair."""

    name = "c"
    cls = complex

    def encode(self, s, v):
        real_part, imag_part = v.real, v.imag
        return (real_part, imag_part)

    def decode(self, s, v):
        real_part, imag_part = v[0], v[1]
        return complex(real_part, imag_part)
class NDArrayExtension(Extension):
    """Extension that stores array-like objects as a dict with the keys
    ``shape``, ``dtype`` and ``data``.
    """

    name = "ndarray"

    def __init__(self):
        # Only expose the ndarray class when numpy was imported already
        if "numpy" not in sys.modules:
            return
        import numpy as np

        self.cls = np.ndarray

    def match(self, s, v):  # pragma: no cover - e.g. work for nd arrays in JS
        required = ("shape", "dtype", "tobytes")
        return all(hasattr(v, attr) for attr in required)

    def encode(self, s, v):
        return {"shape": v.shape, "dtype": text_type(v.dtype), "data": v.tobytes()}

    def decode(self, s, v):
        try:
            import numpy as np
        except ImportError:  # pragma: no cover
            # Without numpy the encoded dict is returned as it is
            return v
        arr = np.frombuffer(v["data"], dtype=v["dtype"])
        arr.shape = v["shape"]
        return arr
- # Extension classes that BsdfSerializer registers when no explicit
- # ``extensions`` argument is given.
- standard_extensions = [ComplexExtension, NDArrayExtension]
- if __name__ == "__main__":
- # Invoke CLI
- # (bsdf_cli is only imported when this file is run as a script)
- import bsdf_cli
- bsdf_cli.main()
|