thread_utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from concurrent.futures import ThreadPoolExecutor, as_completed
  3. from functools import wraps
  4. from tqdm.auto import tqdm
  5. from modelscope.hub.constants import DEFAULT_MAX_WORKERS
  6. from modelscope.utils.logger import get_logger
  7. logger = get_logger()
  8. def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
  9. disable_tqdm: bool = False,
  10. tqdm_desc: str = None):
  11. """
  12. A decorator to execute a function in a threaded manner using ThreadPoolExecutor.
  13. Args:
  14. max_workers (int): The maximum number of threads to use.
  15. disable_tqdm (bool): disable progress bar.
  16. tqdm_desc (str): Desc of tqdm.
  17. Returns:
  18. function: A wrapped function that executes with threading and a progress bar.
  19. Examples:
  20. >>> from modelscope.utils.thread_utils import thread_executor
  21. >>> import time
  22. >>> @thread_executor(max_workers=8)
  23. ... def process_item(item, x, y):
  24. ... # do something to single item
  25. ... time.sleep(1)
  26. ... return str(item) + str(x) + str(y)
  27. >>> items = [1, 2, 3]
  28. >>> process_item(items, x='abc', y='xyz')
  29. """
  30. def decorator(func):
  31. @wraps(func)
  32. def wrapper(iterable, *args, **kwargs):
  33. results = []
  34. # Create a tqdm progress bar with the total number of items to process
  35. with tqdm(
  36. unit_scale=True,
  37. unit_divisor=1024,
  38. initial=0,
  39. total=len(iterable),
  40. desc=tqdm_desc or f'Processing {len(iterable)} items',
  41. disable=disable_tqdm,
  42. ) as pbar:
  43. # Define a wrapper function to update the progress bar
  44. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  45. # Submit all tasks
  46. futures = {
  47. executor.submit(func, item, *args, **kwargs): item
  48. for item in iterable
  49. }
  50. # Update the progress bar as tasks complete
  51. for future in as_completed(futures):
  52. pbar.update(1)
  53. results.append(future.result())
  54. return results
  55. return wrapper
  56. return decorator