caching.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011
  1. from __future__ import annotations
  2. import collections
  3. import functools
  4. import logging
  5. import math
  6. import os
  7. import threading
  8. from collections import OrderedDict
  9. from collections.abc import Callable
  10. from concurrent.futures import Future, ThreadPoolExecutor
  11. from itertools import groupby
  12. from operator import itemgetter
  13. from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, TypeVar
  14. if TYPE_CHECKING:
  15. import mmap
  16. from typing_extensions import ParamSpec
  17. P = ParamSpec("P")
  18. else:
  19. P = TypeVar("P")
  20. T = TypeVar("T")
  21. logger = logging.getLogger("fsspec")
  22. Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes
  23. MultiFetcher = Callable[[list[int, int]], bytes] # Maps [(start, end)] to bytes
  24. class BaseCache:
  25. """Pass-though cache: doesn't keep anything, calls every time
  26. Acts as base class for other cachers
  27. Parameters
  28. ----------
  29. blocksize: int
  30. How far to read ahead in numbers of bytes
  31. fetcher: func
  32. Function of the form f(start, end) which gets bytes from remote as
  33. specified
  34. size: int
  35. How big this file is
  36. """
  37. name: ClassVar[str] = "none"
  38. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  39. self.blocksize = blocksize
  40. self.nblocks = 0
  41. self.fetcher = fetcher
  42. self.size = size
  43. self.hit_count = 0
  44. self.miss_count = 0
  45. # the bytes that we actually requested
  46. self.total_requested_bytes = 0
  47. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  48. if start is None:
  49. start = 0
  50. if stop is None:
  51. stop = self.size
  52. if start >= self.size or start >= stop:
  53. return b""
  54. return self.fetcher(start, stop)
  55. def _reset_stats(self) -> None:
  56. """Reset hit and miss counts for a more ganular report e.g. by file."""
  57. self.hit_count = 0
  58. self.miss_count = 0
  59. self.total_requested_bytes = 0
  60. def _log_stats(self) -> str:
  61. """Return a formatted string of the cache statistics."""
  62. if self.hit_count == 0 and self.miss_count == 0:
  63. # a cache that does nothing, this is for logs only
  64. return ""
  65. return f" , {self.name}: {self.hit_count} hits, {self.miss_count} misses, {self.total_requested_bytes} total requested bytes"
  66. def __repr__(self) -> str:
  67. # TODO: use rich for better formatting
  68. return f"""
  69. <{self.__class__.__name__}:
  70. block size : {self.blocksize}
  71. block count : {self.nblocks}
  72. file size : {self.size}
  73. cache hits : {self.hit_count}
  74. cache misses: {self.miss_count}
  75. total requested bytes: {self.total_requested_bytes}>
  76. """
  77. class MMapCache(BaseCache):
  78. """memory-mapped sparse file cache
  79. Opens temporary file, which is filled blocks-wise when data is requested.
  80. Ensure there is enough disc space in the temporary location.
  81. This cache method might only work on posix
  82. Parameters
  83. ----------
  84. blocksize: int
  85. How far to read ahead in numbers of bytes
  86. fetcher: Fetcher
  87. Function of the form f(start, end) which gets bytes from remote as
  88. specified
  89. size: int
  90. How big this file is
  91. location: str
  92. Where to create the temporary file. If None, a temporary file is
  93. created using tempfile.TemporaryFile().
  94. blocks: set[int]
  95. Set of block numbers that have already been fetched. If None, an empty
  96. set is created.
  97. multi_fetcher: MultiFetcher
  98. Function of the form f([(start, end)]) which gets bytes from remote
  99. as specified. This function is used to fetch multiple blocks at once.
  100. If not specified, the fetcher function is used instead.
  101. """
  102. name = "mmap"
  103. def __init__(
  104. self,
  105. blocksize: int,
  106. fetcher: Fetcher,
  107. size: int,
  108. location: str | None = None,
  109. blocks: set[int] | None = None,
  110. multi_fetcher: MultiFetcher | None = None,
  111. ) -> None:
  112. super().__init__(blocksize, fetcher, size)
  113. self.blocks = set() if blocks is None else blocks
  114. self.location = location
  115. self.multi_fetcher = multi_fetcher
  116. self.cache = self._makefile()
  117. def _makefile(self) -> mmap.mmap | bytearray:
  118. import mmap
  119. import tempfile
  120. if self.size == 0:
  121. return bytearray()
  122. # posix version
  123. if self.location is None or not os.path.exists(self.location):
  124. if self.location is None:
  125. fd = tempfile.TemporaryFile()
  126. self.blocks = set()
  127. else:
  128. fd = open(self.location, "wb+")
  129. fd.seek(self.size - 1)
  130. fd.write(b"1")
  131. fd.flush()
  132. else:
  133. fd = open(self.location, "r+b")
  134. return mmap.mmap(fd.fileno(), self.size)
  135. def _fetch(self, start: int | None, end: int | None) -> bytes:
  136. logger.debug(f"MMap cache fetching {start}-{end}")
  137. if start is None:
  138. start = 0
  139. if end is None:
  140. end = self.size
  141. if start >= self.size or start >= end:
  142. return b""
  143. start_block = start // self.blocksize
  144. end_block = end // self.blocksize
  145. block_range = range(start_block, end_block + 1)
  146. # Determine which blocks need to be fetched. This sequence is sorted by construction.
  147. need = (i for i in block_range if i not in self.blocks)
  148. # Count the number of blocks already cached
  149. self.hit_count += sum(1 for i in block_range if i in self.blocks)
  150. ranges = []
  151. # Consolidate needed blocks.
  152. # Algorithm adapted from Python 2.x itertools documentation.
  153. # We are grouping an enumerated sequence of blocks. By comparing when the difference
  154. # between an ascending range (provided by enumerate) and the needed block numbers
  155. # we can detect when the block number skips values. The key computes this difference.
  156. # Whenever the difference changes, we know that we have previously cached block(s),
  157. # and a new group is started. In other words, this algorithm neatly groups
  158. # runs of consecutive block numbers so they can be fetched together.
  159. for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]):
  160. # Extract the blocks from the enumerated sequence
  161. _blocks = tuple(map(itemgetter(1), _blocks))
  162. # Compute start of first block
  163. sstart = _blocks[0] * self.blocksize
  164. # Compute the end of the last block. Last block may not be full size.
  165. send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size)
  166. # Fetch bytes (could be multiple consecutive blocks)
  167. self.total_requested_bytes += send - sstart
  168. logger.debug(
  169. f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
  170. )
  171. ranges.append((sstart, send))
  172. # Update set of cached blocks
  173. self.blocks.update(_blocks)
  174. # Update cache statistics with number of blocks we had to cache
  175. self.miss_count += len(_blocks)
  176. if not ranges:
  177. return self.cache[start:end]
  178. if self.multi_fetcher:
  179. logger.debug(f"MMap get blocks {ranges}")
  180. for idx, r in enumerate(self.multi_fetcher(ranges)):
  181. (sstart, send) = ranges[idx]
  182. logger.debug(f"MMap copy block ({sstart}-{send}")
  183. self.cache[sstart:send] = r
  184. else:
  185. for sstart, send in ranges:
  186. logger.debug(f"MMap get block ({sstart}-{send}")
  187. self.cache[sstart:send] = self.fetcher(sstart, send)
  188. return self.cache[start:end]
  189. def __getstate__(self) -> dict[str, Any]:
  190. state = self.__dict__.copy()
  191. # Remove the unpicklable entries.
  192. del state["cache"]
  193. return state
  194. def __setstate__(self, state: dict[str, Any]) -> None:
  195. # Restore instance attributes
  196. self.__dict__.update(state)
  197. self.cache = self._makefile()
  198. class ReadAheadCache(BaseCache):
  199. """Cache which reads only when we get beyond a block of data
  200. This is a much simpler version of BytesCache, and does not attempt to
  201. fill holes in the cache or keep fragments alive. It is best suited to
  202. many small reads in a sequential order (e.g., reading lines from a file).
  203. """
  204. name = "readahead"
  205. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  206. super().__init__(blocksize, fetcher, size)
  207. self.cache = b""
  208. self.start = 0
  209. self.end = 0
  210. def _fetch(self, start: int | None, end: int | None) -> bytes:
  211. if start is None:
  212. start = 0
  213. if end is None or end > self.size:
  214. end = self.size
  215. if start >= self.size or start >= end:
  216. return b""
  217. l = end - start
  218. if start >= self.start and end <= self.end:
  219. # cache hit
  220. self.hit_count += 1
  221. return self.cache[start - self.start : end - self.start]
  222. elif self.start <= start < self.end:
  223. # partial hit
  224. self.miss_count += 1
  225. part = self.cache[start - self.start :]
  226. l -= len(part)
  227. start = self.end
  228. else:
  229. # miss
  230. self.miss_count += 1
  231. part = b""
  232. end = min(self.size, end + self.blocksize)
  233. self.total_requested_bytes += end - start
  234. self.cache = self.fetcher(start, end) # new block replaces old
  235. self.start = start
  236. self.end = self.start + len(self.cache)
  237. return part + self.cache[:l]
  238. class FirstChunkCache(BaseCache):
  239. """Caches the first block of a file only
  240. This may be useful for file types where the metadata is stored in the header,
  241. but is randomly accessed.
  242. """
  243. name = "first"
  244. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  245. if blocksize > size:
  246. # this will buffer the whole thing
  247. blocksize = size
  248. super().__init__(blocksize, fetcher, size)
  249. self.cache: bytes | None = None
  250. def _fetch(self, start: int | None, end: int | None) -> bytes:
  251. start = start or 0
  252. if start > self.size:
  253. logger.debug("FirstChunkCache: requested start > file size")
  254. return b""
  255. end = min(end, self.size)
  256. if start < self.blocksize:
  257. if self.cache is None:
  258. self.miss_count += 1
  259. if end > self.blocksize:
  260. self.total_requested_bytes += end
  261. data = self.fetcher(0, end)
  262. self.cache = data[: self.blocksize]
  263. return data[start:]
  264. self.cache = self.fetcher(0, self.blocksize)
  265. self.total_requested_bytes += self.blocksize
  266. part = self.cache[start:end]
  267. if end > self.blocksize:
  268. self.total_requested_bytes += end - self.blocksize
  269. part += self.fetcher(self.blocksize, end)
  270. self.hit_count += 1
  271. return part
  272. else:
  273. self.miss_count += 1
  274. self.total_requested_bytes += end - start
  275. return self.fetcher(start, end)
  276. class BlockCache(BaseCache):
  277. """
  278. Cache holding memory as a set of blocks.
  279. Requests are only ever made ``blocksize`` at a time, and are
  280. stored in an LRU cache. The least recently accessed block is
  281. discarded when more than ``maxblocks`` are stored.
  282. Parameters
  283. ----------
  284. blocksize : int
  285. The number of bytes to store in each block.
  286. Requests are only ever made for ``blocksize``, so this
  287. should balance the overhead of making a request against
  288. the granularity of the blocks.
  289. fetcher : Callable
  290. size : int
  291. The total size of the file being cached.
  292. maxblocks : int
  293. The maximum number of blocks to cache for. The maximum memory
  294. use for this cache is then ``blocksize * maxblocks``.
  295. """
  296. name = "blockcache"
  297. def __init__(
  298. self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
  299. ) -> None:
  300. super().__init__(blocksize, fetcher, size)
  301. self.nblocks = math.ceil(size / blocksize)
  302. self.maxblocks = maxblocks
  303. self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)
  304. def cache_info(self):
  305. """
  306. The statistics on the block cache.
  307. Returns
  308. -------
  309. NamedTuple
  310. Returned directly from the LRU Cache used internally.
  311. """
  312. return self._fetch_block_cached.cache_info()
  313. def __getstate__(self) -> dict[str, Any]:
  314. state = self.__dict__
  315. del state["_fetch_block_cached"]
  316. return state
  317. def __setstate__(self, state: dict[str, Any]) -> None:
  318. self.__dict__.update(state)
  319. self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
  320. self._fetch_block
  321. )
  322. def _fetch(self, start: int | None, end: int | None) -> bytes:
  323. if start is None:
  324. start = 0
  325. if end is None:
  326. end = self.size
  327. if start >= self.size or start >= end:
  328. return b""
  329. # byte position -> block numbers
  330. start_block_number = start // self.blocksize
  331. end_block_number = end // self.blocksize
  332. # these are cached, so safe to do multiple calls for the same start and end.
  333. for block_number in range(start_block_number, end_block_number + 1):
  334. self._fetch_block_cached(block_number)
  335. return self._read_cache(
  336. start,
  337. end,
  338. start_block_number=start_block_number,
  339. end_block_number=end_block_number,
  340. )
  341. def _fetch_block(self, block_number: int) -> bytes:
  342. """
  343. Fetch the block of data for `block_number`.
  344. """
  345. if block_number > self.nblocks:
  346. raise ValueError(
  347. f"'block_number={block_number}' is greater than "
  348. f"the number of blocks ({self.nblocks})"
  349. )
  350. start = block_number * self.blocksize
  351. end = start + self.blocksize
  352. self.total_requested_bytes += end - start
  353. self.miss_count += 1
  354. logger.info("BlockCache fetching block %d", block_number)
  355. block_contents = super()._fetch(start, end)
  356. return block_contents
  357. def _read_cache(
  358. self, start: int, end: int, start_block_number: int, end_block_number: int
  359. ) -> bytes:
  360. """
  361. Read from our block cache.
  362. Parameters
  363. ----------
  364. start, end : int
  365. The start and end byte positions.
  366. start_block_number, end_block_number : int
  367. The start and end block numbers.
  368. """
  369. start_pos = start % self.blocksize
  370. end_pos = end % self.blocksize
  371. self.hit_count += 1
  372. if start_block_number == end_block_number:
  373. block: bytes = self._fetch_block_cached(start_block_number)
  374. return block[start_pos:end_pos]
  375. else:
  376. # read from the initial
  377. out = [self._fetch_block_cached(start_block_number)[start_pos:]]
  378. # intermediate blocks
  379. # Note: it'd be nice to combine these into one big request. However
  380. # that doesn't play nicely with our LRU cache.
  381. out.extend(
  382. map(
  383. self._fetch_block_cached,
  384. range(start_block_number + 1, end_block_number),
  385. )
  386. )
  387. # final block
  388. out.append(self._fetch_block_cached(end_block_number)[:end_pos])
  389. return b"".join(out)
class BytesCache(BaseCache):
    """Cache which holds data in a in-memory bytes object

    Implements read-ahead by the block size, for semi-random reads progressing
    through the file.

    The buffer ``self.cache`` holds file bytes ``self.start``..``self.end``;
    both bounds are ``None`` until the first read.

    Parameters
    ----------
    trim: bool
        As we read more data, whether to discard the start of the buffer when
        we are more than a blocksize ahead of it.
    """

    name: ClassVar[str] = "bytes"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start: int | None = None
        self.end: int | None = None
        self.trim = trim

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        """Return bytes ``start``..``end``, growing or replacing the buffer.

        ``None`` bounds default to the start / end of the file. Depending on
        how the request overlaps the buffer, the buffer is reused, extended
        at the front or back, or replaced by a fresh fetch.
        """
        # TODO: only set start/end after fetch, in case it fails?
        # is this where retry logic might go?
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        if (
            self.start is not None
            and start >= self.start
            and self.end is not None
            and end < self.end
        ):
            # cache hit: we have all the required data
            # (note the strict ``<``: a read ending exactly at ``self.end``
            # is handled by the miss path below)
            offset = start - self.start
            self.hit_count += 1
            return self.cache[offset : offset + end - start]

        # ``bend`` is the end of the request including read-ahead, limited to
        # the file size; a blocksize of 0 disables read-ahead.
        if self.blocksize:
            bend = min(self.size, end + self.blocksize)
        else:
            bend = end

        if bend == start or start > self.size:
            return b""

        if (self.start is None or start < self.start) and (
            self.end is None or end > self.end
        ):
            # First read, or extending both before and after
            self.total_requested_bytes += bend - start
            self.miss_count += 1
            self.cache = self.fetcher(start, bend)
            self.start = start
        else:
            assert self.start is not None
            assert self.end is not None
            self.miss_count += 1

            if start < self.start:
                # The request begins before the buffer.
                if self.end is None or self.end - end > self.blocksize:
                    # Buffer extends more than a block past the request:
                    # replace it.
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    # Fetch only the missing front part and prepend it.
                    self.total_requested_bytes += self.start - start
                    new = self.fetcher(start, self.start)
                    self.start = start
                    self.cache = new + self.cache
            elif self.end is not None and bend > self.end:
                # The request (with read-ahead) runs past the buffer.
                if self.end > self.size:
                    pass
                elif end - self.end > self.blocksize:
                    # Request ends more than a block past the buffer:
                    # replace it.
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    # Fetch only the missing tail and append it.
                    self.total_requested_bytes += bend - self.end
                    new = self.fetcher(self.end, bend)
                    self.cache = self.cache + new

        # Re-derive the buffer end from what was actually received.
        self.end = self.start + len(self.cache)
        offset = start - self.start
        out = self.cache[offset : offset + end - start]
        if self.trim:
            # ``out`` is already sliced, so dropping bytes from the front of
            # the buffer cannot affect the returned data. ``self.start``
            # advances by exactly the number of bytes dropped.
            num = (self.end - self.start) // (self.blocksize + 1)
            if num > 1:
                self.start += self.blocksize * num
                self.cache = self.cache[self.blocksize * num :]
        return out

    def __len__(self) -> int:
        # Number of bytes currently held in the buffer.
        return len(self.cache)
  478. class AllBytes(BaseCache):
  479. """Cache entire contents of the file"""
  480. name: ClassVar[str] = "all"
  481. def __init__(
  482. self,
  483. blocksize: int | None = None,
  484. fetcher: Fetcher | None = None,
  485. size: int | None = None,
  486. data: bytes | None = None,
  487. ) -> None:
  488. super().__init__(blocksize, fetcher, size) # type: ignore[arg-type]
  489. if data is None:
  490. self.miss_count += 1
  491. self.total_requested_bytes += self.size
  492. data = self.fetcher(0, self.size)
  493. self.data = data
  494. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  495. self.hit_count += 1
  496. return self.data[start:stop]
  497. class KnownPartsOfAFile(BaseCache):
  498. """
  499. Cache holding known file parts.
  500. Parameters
  501. ----------
  502. blocksize: int
  503. How far to read ahead in numbers of bytes
  504. fetcher: func
  505. Function of the form f(start, end) which gets bytes from remote as
  506. specified
  507. size: int
  508. How big this file is
  509. data: dict
  510. A dictionary mapping explicit `(start, stop)` file-offset tuples
  511. with known bytes.
  512. strict: bool, default True
  513. Whether to fetch reads that go beyond a known byte-range boundary.
  514. If `False`, any read that ends outside a known part will be zero
  515. padded. Note that zero padding will not be used for reads that
  516. begin outside a known byte-range.
  517. """
  518. name: ClassVar[str] = "parts"
  519. def __init__(
  520. self,
  521. blocksize: int,
  522. fetcher: Fetcher,
  523. size: int,
  524. data: dict[tuple[int, int], bytes] | None = None,
  525. strict: bool = False,
  526. **_: Any,
  527. ):
  528. super().__init__(blocksize, fetcher, size)
  529. self.strict = strict
  530. # simple consolidation of contiguous blocks
  531. if data:
  532. old_offsets = sorted(data.keys())
  533. offsets = [old_offsets[0]]
  534. blocks = [data.pop(old_offsets[0])]
  535. for start, stop in old_offsets[1:]:
  536. start0, stop0 = offsets[-1]
  537. if start == stop0:
  538. offsets[-1] = (start0, stop)
  539. blocks[-1] += data.pop((start, stop))
  540. else:
  541. offsets.append((start, stop))
  542. blocks.append(data.pop((start, stop)))
  543. self.data = dict(zip(offsets, blocks))
  544. else:
  545. self.data = {}
  546. @property
  547. def size(self):
  548. return sum(_[1] - _[0] for _ in self.data)
  549. @size.setter
  550. def size(self, value):
  551. pass
  552. @property
  553. def nblocks(self):
  554. return len(self.data)
  555. @nblocks.setter
  556. def nblocks(self, value):
  557. pass
  558. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  559. if start is None:
  560. start = 0
  561. if stop is None:
  562. stop = self.size
  563. self.total_requested_bytes += stop - start
  564. out = b""
  565. started = False
  566. loc_old = 0
  567. for loc0, loc1 in sorted(self.data):
  568. if (loc0 <= start < loc1) and (loc0 <= stop <= loc1):
  569. # entirely within the block
  570. off = start - loc0
  571. self.hit_count += 1
  572. return self.data[(loc0, loc1)][off : off + stop - start]
  573. if stop <= loc0:
  574. break
  575. if started and loc0 > loc_old:
  576. # a gap where we need data
  577. self.miss_count += 1
  578. if self.strict:
  579. raise ValueError
  580. out += b"\x00" * (loc0 - loc_old)
  581. if loc0 <= start < loc1:
  582. # found the start
  583. self.hit_count += 1
  584. off = start - loc0
  585. out = self.data[(loc0, loc1)][off : off + stop - start]
  586. started = True
  587. elif start < loc0 and stop > loc1:
  588. # the whole block
  589. self.hit_count += 1
  590. out += self.data[(loc0, loc1)]
  591. elif loc0 <= stop <= loc1:
  592. # end block
  593. self.hit_count += 1
  594. return out + self.data[(loc0, loc1)][: stop - loc0]
  595. loc_old = loc1
  596. self.miss_count += 1
  597. if started and not self.strict:
  598. return out + b"\x00" * (stop - loc_old)
  599. raise ValueError
  600. class UpdatableLRU(Generic[P, T]):
  601. """
  602. Custom implementation of LRU cache that allows updating keys
  603. Used by BackgroudBlockCache
  604. """
  605. class CacheInfo(NamedTuple):
  606. hits: int
  607. misses: int
  608. maxsize: int
  609. currsize: int
  610. def __init__(self, func: Callable[P, T], max_size: int = 128) -> None:
  611. self._cache: OrderedDict[Any, T] = collections.OrderedDict()
  612. self._func = func
  613. self._max_size = max_size
  614. self._hits = 0
  615. self._misses = 0
  616. self._lock = threading.Lock()
  617. def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
  618. if kwargs:
  619. raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}")
  620. with self._lock:
  621. if args in self._cache:
  622. self._cache.move_to_end(args)
  623. self._hits += 1
  624. return self._cache[args]
  625. result = self._func(*args, **kwargs)
  626. with self._lock:
  627. self._cache[args] = result
  628. self._misses += 1
  629. if len(self._cache) > self._max_size:
  630. self._cache.popitem(last=False)
  631. return result
  632. def is_key_cached(self, *args: Any) -> bool:
  633. with self._lock:
  634. return args in self._cache
  635. def add_key(self, result: T, *args: Any) -> None:
  636. with self._lock:
  637. self._cache[args] = result
  638. if len(self._cache) > self._max_size:
  639. self._cache.popitem(last=False)
  640. def cache_info(self) -> UpdatableLRU.CacheInfo:
  641. with self._lock:
  642. return self.CacheInfo(
  643. maxsize=self._max_size,
  644. currsize=len(self._cache),
  645. hits=self._hits,
  646. misses=self._misses,
  647. )
  648. class BackgroundBlockCache(BaseCache):
  649. """
  650. Cache holding memory as a set of blocks with pre-loading of
  651. the next block in the background.
  652. Requests are only ever made ``blocksize`` at a time, and are
  653. stored in an LRU cache. The least recently accessed block is
  654. discarded when more than ``maxblocks`` are stored. If the
  655. next block is not in cache, it is loaded in a separate thread
  656. in non-blocking way.
  657. Parameters
  658. ----------
  659. blocksize : int
  660. The number of bytes to store in each block.
  661. Requests are only ever made for ``blocksize``, so this
  662. should balance the overhead of making a request against
  663. the granularity of the blocks.
  664. fetcher : Callable
  665. size : int
  666. The total size of the file being cached.
  667. maxblocks : int
  668. The maximum number of blocks to cache for. The maximum memory
  669. use for this cache is then ``blocksize * maxblocks``.
  670. """
  671. name: ClassVar[str] = "background"
  672. def __init__(
  673. self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
  674. ) -> None:
  675. super().__init__(blocksize, fetcher, size)
  676. self.nblocks = math.ceil(size / blocksize)
  677. self.maxblocks = maxblocks
  678. self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks)
  679. self._thread_executor = ThreadPoolExecutor(max_workers=1)
  680. self._fetch_future_block_number: int | None = None
  681. self._fetch_future: Future[bytes] | None = None
  682. self._fetch_future_lock = threading.Lock()
  683. def cache_info(self) -> UpdatableLRU.CacheInfo:
  684. """
  685. The statistics on the block cache.
  686. Returns
  687. -------
  688. NamedTuple
  689. Returned directly from the LRU Cache used internally.
  690. """
  691. return self._fetch_block_cached.cache_info()
  692. def __getstate__(self) -> dict[str, Any]:
  693. state = self.__dict__
  694. del state["_fetch_block_cached"]
  695. del state["_thread_executor"]
  696. del state["_fetch_future_block_number"]
  697. del state["_fetch_future"]
  698. del state["_fetch_future_lock"]
  699. return state
  700. def __setstate__(self, state) -> None:
  701. self.__dict__.update(state)
  702. self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"])
  703. self._thread_executor = ThreadPoolExecutor(max_workers=1)
  704. self._fetch_future_block_number = None
  705. self._fetch_future = None
  706. self._fetch_future_lock = threading.Lock()
  707. def _fetch(self, start: int | None, end: int | None) -> bytes:
  708. if start is None:
  709. start = 0
  710. if end is None:
  711. end = self.size
  712. if start >= self.size or start >= end:
  713. return b""
  714. # byte position -> block numbers
  715. start_block_number = start // self.blocksize
  716. end_block_number = end // self.blocksize
  717. fetch_future_block_number = None
  718. fetch_future = None
  719. with self._fetch_future_lock:
  720. # Background thread is running. Check we we can or must join it.
  721. if self._fetch_future is not None:
  722. assert self._fetch_future_block_number is not None
  723. if self._fetch_future.done():
  724. logger.info("BlockCache joined background fetch without waiting.")
  725. self._fetch_block_cached.add_key(
  726. self._fetch_future.result(), self._fetch_future_block_number
  727. )
  728. # Cleanup the fetch variables. Done with fetching the block.
  729. self._fetch_future_block_number = None
  730. self._fetch_future = None
  731. else:
  732. # Must join if we need the block for the current fetch
  733. must_join = bool(
  734. start_block_number
  735. <= self._fetch_future_block_number
  736. <= end_block_number
  737. )
  738. if must_join:
  739. # Copy to the local variables to release lock
  740. # before waiting for result
  741. fetch_future_block_number = self._fetch_future_block_number
  742. fetch_future = self._fetch_future
  743. # Cleanup the fetch variables. Have a local copy.
  744. self._fetch_future_block_number = None
  745. self._fetch_future = None
  746. # Need to wait for the future for the current read
  747. if fetch_future is not None:
  748. logger.info("BlockCache waiting for background fetch.")
  749. # Wait until result and put it in cache
  750. self._fetch_block_cached.add_key(
  751. fetch_future.result(), fetch_future_block_number
  752. )
  753. # these are cached, so safe to do multiple calls for the same start and end.
  754. for block_number in range(start_block_number, end_block_number + 1):
  755. self._fetch_block_cached(block_number)
  756. # fetch next block in the background if nothing is running in the background,
  757. # the block is within file and it is not already cached
  758. end_block_plus_1 = end_block_number + 1
  759. with self._fetch_future_lock:
  760. if (
  761. self._fetch_future is None
  762. and end_block_plus_1 <= self.nblocks
  763. and not self._fetch_block_cached.is_key_cached(end_block_plus_1)
  764. ):
  765. self._fetch_future_block_number = end_block_plus_1
  766. self._fetch_future = self._thread_executor.submit(
  767. self._fetch_block, end_block_plus_1, "async"
  768. )
  769. return self._read_cache(
  770. start,
  771. end,
  772. start_block_number=start_block_number,
  773. end_block_number=end_block_number,
  774. )
  775. def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
  776. """
  777. Fetch the block of data for `block_number`.
  778. """
  779. if block_number > self.nblocks:
  780. raise ValueError(
  781. f"'block_number={block_number}' is greater than "
  782. f"the number of blocks ({self.nblocks})"
  783. )
  784. start = block_number * self.blocksize
  785. end = start + self.blocksize
  786. logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
  787. self.total_requested_bytes += end - start
  788. self.miss_count += 1
  789. block_contents = super()._fetch(start, end)
  790. return block_contents
  791. def _read_cache(
  792. self, start: int, end: int, start_block_number: int, end_block_number: int
  793. ) -> bytes:
  794. """
  795. Read from our block cache.
  796. Parameters
  797. ----------
  798. start, end : int
  799. The start and end byte positions.
  800. start_block_number, end_block_number : int
  801. The start and end block numbers.
  802. """
  803. start_pos = start % self.blocksize
  804. end_pos = end % self.blocksize
  805. # kind of pointless to count this as a hit, but it is
  806. self.hit_count += 1
  807. if start_block_number == end_block_number:
  808. block = self._fetch_block_cached(start_block_number)
  809. return block[start_pos:end_pos]
  810. else:
  811. # read from the initial
  812. out = [self._fetch_block_cached(start_block_number)[start_pos:]]
  813. # intermediate blocks
  814. # Note: it'd be nice to combine these into one big request. However
  815. # that doesn't play nicely with our LRU cache.
  816. out.extend(
  817. map(
  818. self._fetch_block_cached,
  819. range(start_block_number + 1, end_block_number),
  820. )
  821. )
  822. # final block
  823. out.append(self._fetch_block_cached(end_block_number)[:end_pos])
  824. return b"".join(out)
  825. caches: dict[str | None, type[BaseCache]] = {
  826. # one custom case
  827. None: BaseCache,
  828. }
  829. def register_cache(cls: type[BaseCache], clobber: bool = False) -> None:
  830. """'Register' cache implementation.
  831. Parameters
  832. ----------
  833. clobber: bool, optional
  834. If set to True (default is False) - allow to overwrite existing
  835. entry.
  836. Raises
  837. ------
  838. ValueError
  839. """
  840. name = cls.name
  841. if not clobber and name in caches:
  842. raise ValueError(f"Cache with name {name!r} is already known: {caches[name]}")
  843. caches[name] = cls
  844. for c in (
  845. BaseCache,
  846. MMapCache,
  847. BytesCache,
  848. ReadAheadCache,
  849. BlockCache,
  850. FirstChunkCache,
  851. AllBytes,
  852. KnownPartsOfAFile,
  853. BackgroundBlockCache,
  854. ):
  855. register_cache(c)