utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. import base64
  2. import hashlib
  3. import math
  4. import os
  5. import re
  6. from collections.abc import Mapping
  7. from copy import deepcopy
  8. from io import BytesIO
  9. from typing import Any, Callable, List, TypeVar, Union, Tuple, Set, Dict, Type, Optional, Sequence
  10. import numpy as np
  11. import requests
  12. import torch
  13. from packaging import version
# Type aliases used in the signatures of the helpers in this module.
# History: a list whose items are either a (str, str) tuple or a list of str.
History = List[Union[Tuple[str, str], List[str]]]
# Prompt: a list whose items are a str, a list of int, or a list of str.
Prompt = List[Union[str, List[int], List[str]]]
# StopWords has exactly the same shape as Prompt.
StopWords = Prompt
# Context: a str or a list of int (presumably raw text vs. token ids -- confirm).
Context = Union[str, List[int]]
# Messages: a list of dicts with str keys whose values are a str or a list of dicts.
Messages = List[Dict[str, Union[str, List[Dict]]]]
# >>> internvl
# Per-channel mean / std handed to T.Normalize in _build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
  22. def split_str_parts_by(text: str, delimiters: List[str]):
  23. """Split the text field into parts.
  24. Args:
  25. text: A text to be split.
  26. delimiters: The delimiters.
  27. Returns:
  28. The split text in list of dicts.
  29. """
  30. assert isinstance(text, str), f'text: {text}'
  31. all_start_chars = [d[0] for d in delimiters]
  32. all_length = [len(d) for d in delimiters]
  33. text_list = []
  34. last_words = ''
  35. while len(text) > 0:
  36. for char_idx, char in enumerate(text):
  37. match_index = [idx for idx, start_char in enumerate(all_start_chars) if start_char == char]
  38. is_delimiter = False
  39. for index in match_index:
  40. if text[char_idx:char_idx + all_length[index]] == delimiters[index]:
  41. if text_list:
  42. text_list[-1]['content'] = last_words
  43. elif last_words:
  44. text_list.append({'key': '', 'content': last_words})
  45. last_words = ''
  46. text_list.append({'key': delimiters[index]})
  47. text = text[char_idx + all_length[index]:]
  48. is_delimiter = True
  49. break
  50. if not is_delimiter:
  51. last_words += char
  52. else:
  53. break
  54. if last_words == text:
  55. text = ''
  56. if len(text_list):
  57. text_list[-1]['content'] = last_words
  58. else:
  59. text_list.append({'key': '', 'content': last_words})
  60. return text_list
  61. def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> None:
  62. import re
  63. compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
  64. for i in range(len(text_list) - 1, -1, -1):
  65. item = text_list[i]
  66. if item.get('key') == '':
  67. res_text = item['content']
  68. last_idx = 0
  69. segments = []
  70. for pattern, scale in compiled_patterns:
  71. matches = list(re.finditer(pattern, res_text))
  72. for match in matches:
  73. if match.start() > last_idx:
  74. segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
  75. segments.append({'key': scale[0], 'content': match.group(0)})
  76. last_idx = match.end()
  77. if last_idx < len(res_text):
  78. segments.insert(0, {'key': '', 'content': res_text[last_idx:]})
  79. if segments:
  80. text_list[i:i + 1] = segments
  81. def _decode_prompt(prompt: str, tmp_dir: str = 'tmp') -> str:
  82. pattern = r'<(?:img|audio|video)>(.+?)</(?:img|audio|video)>'
  83. match_iter = re.finditer(pattern, prompt)
  84. new_content = ''
  85. idx = 0
  86. for m in match_iter:
  87. span = m.span(1)
  88. img_base64 = m.group(1)
  89. img_path = _from_base64(img_base64, tmp_dir)
  90. new_content += prompt[idx:span[0]] + img_path
  91. idx = span[1]
  92. new_content += prompt[idx:]
  93. return new_content
  94. def _to_base64(img_path: Union[str, 'PIL.Image.Image', bytes]) -> str:
  95. if isinstance(img_path, str) and not os.path.isfile(img_path):
  96. # base64
  97. return img_path
  98. if isinstance(img_path, str):
  99. # local_path
  100. with open(img_path, 'rb') as f:
  101. _bytes = f.read()
  102. elif not isinstance(img_path, bytes): # PIL.Image.Image
  103. bytes_io = BytesIO()
  104. img_path.save(bytes_io, format='png')
  105. _bytes = bytes_io.getvalue()
  106. else:
  107. _bytes = img_path
  108. img_base64: str = base64.b64encode(_bytes).decode('utf-8')
  109. return img_base64
  110. def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp') -> str:
  111. from PIL import Image
  112. if not isinstance(img_base64, str): # PIL.Image.Image
  113. img_base64 = _to_base64(img_base64)
  114. if os.path.isfile(img_base64) or img_base64.startswith('http'):
  115. return img_base64
  116. sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
  117. img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
  118. image = Image.open(BytesIO(base64.b64decode(img_base64)))
  119. if not os.path.exists(img_path):
  120. image.save(img_path)
  121. return img_path
  122. def decode_base64(*,
  123. messages: Optional[Messages] = None,
  124. prompt: Optional[str] = None,
  125. images: Optional[List[str]] = None,
  126. tmp_dir: str = 'tmp') -> Dict[str, Any]:
  127. # base64 -> local_path
  128. os.makedirs(tmp_dir, exist_ok=True)
  129. res = {}
  130. if messages is not None:
  131. res_messages = []
  132. for m in messages:
  133. m_new = deepcopy(m)
  134. m_new['content'] = _decode_prompt(m_new['content'], tmp_dir)
  135. res_messages.append(m_new)
  136. res['messages'] = res_messages
  137. if prompt is not None:
  138. prompt = _decode_prompt(prompt, tmp_dir)
  139. res['prompt'] = prompt
  140. if images is not None:
  141. res_images = []
  142. for image in images:
  143. image = _from_base64(image, tmp_dir)
  144. res_images.append(image)
  145. res['images'] = res_images
  146. return res
  147. def to_device(inputs: Any, device: torch.device) -> Any:
  148. """Move inputs to a device"""
  149. if callable(getattr(inputs, 'to', None)):
  150. return inputs.to(device=device)
  151. if isinstance(inputs, Mapping):
  152. res = {}
  153. for k, v in inputs.items():
  154. res[k] = to_device(v, device)
  155. elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
  156. res = []
  157. for b in inputs:
  158. res.append(to_device(b, device))
  159. else:
  160. res = inputs
  161. return res
  162. def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
  163. # The upper bound satisfying the condition "cond".
  164. while lo < hi:
  165. mid = (lo + hi + 1) >> 1 # lo + (hi-lo+1)>>1
  166. if cond(mid):
  167. lo = mid
  168. else:
  169. hi = mid - 1
  170. return lo
  171. def fetch_one(element: Union[Tuple, List, Set, Dict, Any], type: Type = None) -> Any:
  172. if isinstance(element, (tuple, set, list)):
  173. for ele in element:
  174. out = fetch_one(ele)
  175. if out and (type is None or isinstance(out, type)):
  176. return out
  177. elif isinstance(element, dict):
  178. return fetch_one(list(element.values()))
  179. else:
  180. return element
  181. def _build_transform(input_size):
  182. import torchvision.transforms as T
  183. from torchvision.transforms.functional import InterpolationMode
  184. MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
  185. transform = T.Compose([
  186. T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
  187. T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
  188. T.ToTensor(),
  189. T.Normalize(mean=MEAN, std=STD)
  190. ])
  191. return transform
  192. def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
  193. best_ratio_diff = float('inf')
  194. best_ratio = (1, 1)
  195. area = width * height
  196. for ratio in target_ratios:
  197. target_aspect_ratio = ratio[0] / ratio[1]
  198. ratio_diff = abs(aspect_ratio - target_aspect_ratio)
  199. if ratio_diff < best_ratio_diff:
  200. best_ratio_diff = ratio_diff
  201. best_ratio = ratio
  202. elif ratio_diff == best_ratio_diff:
  203. if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
  204. best_ratio = ratio
  205. return best_ratio
  206. def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
  207. orig_width, orig_height = image.size
  208. aspect_ratio = orig_width / orig_height
  209. # calculate the existing image aspect ratio
  210. target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
  211. if i * j <= max_num and i * j >= min_num)
  212. target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
  213. # find the closest aspect ratio to the target
  214. target_aspect_ratio = _find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
  215. # calculate the target width and height
  216. target_width = image_size * target_aspect_ratio[0]
  217. target_height = image_size * target_aspect_ratio[1]
  218. blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
  219. # resize the image
  220. resized_img = image.resize((target_width, target_height))
  221. processed_images = []
  222. for i in range(blocks):
  223. box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size,
  224. ((i % (target_width // image_size)) + 1) * image_size, ((i //
  225. (target_width // image_size)) + 1) * image_size)
  226. # split the image
  227. split_img = resized_img.crop(box)
  228. processed_images.append(split_img)
  229. assert len(processed_images) == blocks
  230. if use_thumbnail and len(processed_images) != 1:
  231. thumbnail_img = image.resize((image_size, image_size))
  232. processed_images.append(thumbnail_img)
  233. return processed_images
  234. # <<< internvl
  235. def rescale_image(img: 'PIL.Image.Image', rescale_image: int = -1) -> 'PIL.Image.Image':
  236. import torchvision.transforms as T
  237. width = img.width
  238. height = img.height
  239. if rescale_image <= 0 or width * height <= rescale_image:
  240. return img
  241. ratio = width / height
  242. height_scaled = math.pow(rescale_image / ratio, 0.5)
  243. width_scaled = height_scaled * ratio
  244. return T.Resize((int(width_scaled), int(height_scaled)))(img)
  245. _T = TypeVar('_T')
  246. def load_file(path: Union[str, _T]) -> Union[BytesIO, _T]:
  247. res = path
  248. if isinstance(path, str):
  249. path = path.strip()
  250. if path.startswith('http'):
  251. request_kwargs = {}
  252. timeout = float(os.getenv('TIMEOUT', '60'))
  253. if timeout > 0:
  254. request_kwargs['timeout'] = timeout
  255. content = requests.get(path, **request_kwargs).content
  256. res = BytesIO(content)
  257. elif os.path.exists(path):
  258. with open(path, 'rb') as f:
  259. res = BytesIO(f.read())
  260. else: # base64_str
  261. import binascii
  262. try:
  263. data = base64.b64decode(path)
  264. res = BytesIO(data)
  265. except (ValueError, binascii.Error) as error:
  266. if len(path) < 200:
  267. raise ValueError(f'invalid image: "{path}"')
  268. else:
  269. raise ValueError(f'invalid image: {error}')
  270. return res
  271. def load_file_decorator(func):
  272. def new_func(path, *args, **kwargs):
  273. path = load_file(path)
  274. res = func(path, *args, **kwargs)
  275. return res
  276. return new_func
  277. @load_file_decorator
  278. def load_image(image: Union['PIL.Image.Image', BytesIO]) -> 'PIL.Image.Image':
  279. from PIL import Image
  280. if isinstance(image, BytesIO):
  281. image = Image.open(image)
  282. if image.mode != 'RGB':
  283. image = image.convert('RGB')
  284. return image
  285. def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
  286. load_func: Callable[[Any], _T] = load_image) -> List[_T]:
  287. res = []
  288. assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
  289. for path in path_list:
  290. if path is None: # ignore None
  291. continue
  292. res.append(load_func(path))
  293. return res
  294. def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
  295. if bound:
  296. start, end = bound[0], bound[1]
  297. else:
  298. start, end = -100000, 100000
  299. start_idx = max(first_idx, round(start * fps))
  300. end_idx = min(round(end * fps), max_frame)
  301. seg_size = float(end_idx - start_idx) / num_segments
  302. frame_indices = np.array(
  303. [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
  304. return frame_indices
  305. def transform_image(image, input_size=448, max_num=12):
  306. transform = _build_transform(input_size=input_size)
  307. images = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
  308. pixel_values = [transform(image) for image in images]
  309. pixel_values = torch.stack(pixel_values)
  310. return pixel_values
  311. @load_file_decorator
  312. def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
  313. from decord import VideoReader, cpu
  314. from PIL import Image
  315. vr = VideoReader(video_io, ctx=cpu(0), num_threads=1)
  316. max_frame = len(vr) - 1
  317. fps = float(vr.get_avg_fps())
  318. images = []
  319. frame_indices = _get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
  320. for frame_index in frame_indices:
  321. images.append(Image.fromarray(vr[frame_index].asnumpy()).convert('RGB'))
  322. return images
  323. def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
  324. from PIL import Image, ImageDraw
  325. from swift.llm.template.template import Template
  326. image = Image.open(img_dir)
  327. objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
  328. Template.normalize_bbox(objects, [image], 'real')
  329. bbox = objects[0]['bbox']
  330. draw = ImageDraw.Draw(image)
  331. draw.rectangle(bbox, outline='red', width=2)
  332. image.save(output_file)
  333. @load_file_decorator
  334. def load_video_cogvlm2(video_io: BytesIO) -> np.ndarray:
  335. from decord import cpu, VideoReader, bridge
  336. bridge.set_bridge('torch')
  337. clip_end_sec = 60
  338. clip_start_sec = 0
  339. num_frames = 24
  340. decord_vr = VideoReader(video_io, ctx=cpu(0))
  341. duration = len(decord_vr) # duration in terms of frames
  342. start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
  343. end_frame = min(duration, int(clip_end_sec * decord_vr.get_avg_fps())) if \
  344. clip_end_sec is not None else duration
  345. frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
  346. video_data = decord_vr.get_batch(frame_id_list)
  347. video_data = video_data.permute(3, 0, 1, 2)
  348. return video_data
  349. @load_file_decorator
  350. def load_video_llava(video_io: BytesIO) -> np.ndarray:
  351. import av
  352. container = av.open(video_io)
  353. total_frames = container.streams.video[0].frames
  354. indices = np.arange(0, total_frames, total_frames / 8).astype(int)
  355. frames = []
  356. container.seek(0)
  357. start_index = indices[0]
  358. end_index = indices[-1]
  359. for i, frame in enumerate(container.decode(video=0)):
  360. if i > end_index:
  361. break
  362. if i >= start_index and i in indices:
  363. frames.append(frame)
  364. return np.stack([x.to_ndarray(format='rgb24') for x in frames])
  365. @load_file_decorator
  366. def load_video_minicpmv_mplug_owl3(video_io: BytesIO, max_num_frames):
  367. from PIL import Image
  368. from decord import VideoReader, cpu # pip install decord
  369. def uniform_sample(_l, _n):
  370. gap = len(_l) / _n
  371. idxs = [int(i * gap + gap / 2) for i in range(_n)]
  372. return [_l[i] for i in idxs]
  373. vr = VideoReader(video_io, ctx=cpu(0))
  374. sample_fps = round(vr.get_avg_fps() / 1) # FPS
  375. frame_idx = [i for i in range(0, len(vr), sample_fps)]
  376. if len(frame_idx) > max_num_frames:
  377. frame_idx = uniform_sample(frame_idx, max_num_frames)
  378. frames = vr.get_batch(frame_idx).asnumpy()
  379. frames = [Image.fromarray(v.astype('uint8')) for v in frames]
  380. return frames
  381. @load_file_decorator
  382. def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
  383. import librosa
  384. return librosa.load(audio_io, sr=sampling_rate)[0]
  385. def load_video_qwen2(video_path: str):
  386. from swift.llm.template.template import get_env_args
  387. import torchvision
  388. from torchvision import io, transforms
  389. from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
  390. VIDEO_MIN_PIXELS, VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS, smart_resize,
  391. ceil_by_factor, floor_by_factor)
  392. from torchvision.transforms import InterpolationMode
  393. if version.parse(torchvision.__version__) >= version.parse('0.19'):
  394. video_path = load_file(video_path)
  395. video, _, info = io.read_video(
  396. video_path,
  397. pts_unit='sec',
  398. output_format='TCHW',
  399. )
  400. nframes = get_env_args('nframes', int, None)
  401. fps = get_env_args('fps', int, None)
  402. size_factor = get_env_args('size_factor', int, FRAME_FACTOR)
  403. assert not (fps and nframes), 'Only accept either `fps` or `nframes`'
  404. if nframes is not None:
  405. nframes = round_by_factor(nframes, size_factor)
  406. else:
  407. fps = FPS
  408. nframes = video.size(0) / info['video_fps'] * fps
  409. nframes = round_by_factor(nframes, size_factor)
  410. min_frames = get_env_args('min_frames', int, FPS_MIN_FRAMES)
  411. max_frames = get_env_args('max_frames', int, FPS_MAX_FRAMES)
  412. if nframes < min_frames:
  413. nframes = ceil_by_factor(min_frames, size_factor)
  414. if nframes > max_frames:
  415. nframes = floor_by_factor(max_frames, size_factor)
  416. if not (size_factor <= nframes and nframes <= video.size(0)):
  417. raise ValueError(f'nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.')
  418. idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
  419. height, width = video.shape[2:]
  420. video = video[idx]
  421. min_pixels = get_env_args('min_pixels', int, VIDEO_MIN_PIXELS)
  422. total_pixels = get_env_args('total_pixels', int, VIDEO_TOTAL_PIXELS)
  423. max_pixels = get_env_args('max_pixels', int, None)
  424. if max_pixels is None:
  425. max_pixels = VIDEO_MAX_PIXELS
  426. max_pixels = max(min(max_pixels, total_pixels / nframes * size_factor), min_pixels * 1.05)
  427. # resize
  428. resized_height = get_env_args('resized_height', int, None)
  429. resized_width = get_env_args('resized_width', int, None)
  430. if resized_height and resized_width:
  431. resized_height, resized_width = smart_resize(
  432. resized_height,
  433. resized_width,
  434. factor=size_factor,
  435. )
  436. else:
  437. resized_height, resized_width = smart_resize(
  438. height,
  439. width,
  440. factor=size_factor,
  441. min_pixels=min_pixels,
  442. max_pixels=max_pixels,
  443. )
  444. video = transforms.functional.resize(
  445. video,
  446. [resized_height, resized_width],
  447. interpolation=InterpolationMode.BICUBIC,
  448. antialias=True,
  449. ).float()
  450. return video
  451. if __name__ == '__main__':
  452. # A test main to draw bbox
  453. draw_plot('man.jpg', [354, 462, 580, 738], 'norm_1000', 'man_bbox.jpg')