| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- import base64
- import hashlib
- import math
- import os
- import re
- from collections.abc import Mapping
- from copy import deepcopy
- from io import BytesIO
- from typing import Any, Callable, List, TypeVar, Union, Tuple, Set, Dict, Type, Optional, Sequence
- import numpy as np
- import requests
- import torch
- from packaging import version
# Chat history: a list of turns, each either a (str, str) tuple or a list of str
# (presumably query/response pairs -- confirm with callers).
History = List[Union[Tuple[str, str], List[str]]]
# Prompt template: a list whose parts are strings, lists of ints (presumably
# token ids -- TODO confirm) or lists of strings.
Prompt = List[Union[str, List[int], List[str]]]
# Stop words share the same structure as Prompt.
StopWords = Prompt
# A single context part: a string or a list of ints (presumably token ids).
Context = Union[str, List[int]]
# Chat messages: dicts whose values are strings or lists of dicts (the list form
# presumably carries structured/multimodal content parts -- confirm).
Messages = List[Dict[str, Union[str, List[Dict]]]]
# >>> internvl
# Per-channel RGB mean/std consumed by the Normalize step in _build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def split_str_parts_by(text: str, delimiters: List[str]):
    """Cut ``text`` at every occurrence of any of ``delimiters``.

    Scanning runs left to right. At a given position the first delimiter (in
    list order) that matches is taken, and scanning resumes right after it.
    Each delimiter occurrence becomes ``{'key': delimiter, 'content': <text up
    to the next delimiter>}``. Text before the first delimiter becomes a leading
    ``{'key': '', 'content': ...}`` entry, but only when it is non-empty. When
    no delimiter occurs at all, the result is ``[{'key': '', 'content': text}]``.

    Args:
        text: The string to split.
        delimiters: Literal delimiter strings (an empty one raises IndexError).

    Returns:
        A list of dicts with ``'key'`` and ``'content'`` entries.
    """
    assert isinstance(text, str), f'text: {text}'
    # Group delimiters by their first character, preserving list order inside
    # each group so that earlier delimiters keep priority.
    buckets: Dict[str, List[str]] = {}
    for delim in delimiters:
        buckets.setdefault(delim[0], []).append(delim)

    parts = []
    seg_start = 0
    pos = 0
    total = len(text)
    while pos < total:
        hit = next((d for d in buckets.get(text[pos], ()) if text.startswith(d, pos)), None)
        if hit is None:
            pos += 1
            continue
        segment = text[seg_start:pos]
        if parts:
            parts[-1]['content'] = segment
        elif segment:
            parts.append({'key': '', 'content': segment})
        parts.append({'key': hit})
        pos += len(hit)
        seg_start = pos

    tail = text[seg_start:]
    if parts:
        parts[-1]['content'] = tail
    else:
        parts.append({'key': '', 'content': tail})
    return parts
def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> None:
    """Further split the plain-text parts of ``text_list`` in place using regexes.

    Only items whose ``'key'`` is ``''`` (plain text) are split; every other
    item is left untouched. Each plain item is replaced by an ordered run of
    segments. Text between matches keeps ``'key': ''``. Each match gets
    ``'key': scale[0]``, where ``scale`` is the value ``regex_delimiters`` maps
    its pattern to.

    Matches of all patterns are merged in text order. At equal start positions
    the earlier pattern wins, and a match overlapping an already emitted one is
    skipped.

    Args:
        text_list: List of part dicts; modified in place.
        regex_delimiters: Mapping from regex pattern to a list whose first
            element becomes the ``'key'`` of each matched segment.

    Returns:
        None. ``text_list`` is updated in place.
    """
    compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
    # Iterate backwards so that replacing one item with several does not shift
    # the indices of items that are still to be visited.
    for i in range(len(text_list) - 1, -1, -1):
        item = text_list[i]
        if item.get('key') != '':
            continue
        res_text = item['content']
        # Collect the matches of every pattern and order them by position. The
        # sort is stable, so ties keep pattern order (and finditer order).
        found = [(match, scale, order)
                 for order, (pattern, scale) in enumerate(compiled_patterns)
                 for match in pattern.finditer(res_text)]
        found.sort(key=lambda entry: (entry[0].start(), entry[2]))

        segments = []
        last_idx = 0
        for match, scale, _ in found:
            if match.start() < last_idx:
                # Overlaps a match that has already been emitted.
                continue
            if match.start() > last_idx:
                segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
            segments.append({'key': scale[0], 'content': match.group(0)})
            last_idx = match.end()
        if last_idx < len(res_text):
            # Trailing text after the last match belongs at the END, not the front.
            segments.append({'key': '', 'content': res_text[last_idx:]})
        if segments:
            text_list[i:i + 1] = segments
def _decode_prompt(prompt: str, tmp_dir: str = 'tmp') -> str:
    """Swap the payload inside each <img>/<audio>/<video> tag for a local reference.

    The surrounding tags are preserved. Only the text between them is passed
    through ``_from_base64`` (with ``tmp_dir``), and that result is spliced in.
    Text outside the tags is copied verbatim.
    """
    pattern = r'<(?:img|audio|video)>(.+?)</(?:img|audio|video)>'

    def _swap_payload(found):
        whole = found.group(0)
        inner_from = found.start(1) - found.start()
        inner_to = found.end(1) - found.start()
        return whole[:inner_from] + _from_base64(found.group(1), tmp_dir) + whole[inner_to:]

    return re.sub(pattern, _swap_payload, prompt)
def _to_base64(img_path: Union[str, 'PIL.Image.Image', bytes]) -> str:
    """Encode an image reference as a base64 string.

    * ``str`` naming an existing file: the file's bytes are encoded.
    * any other ``str``: assumed to already be base64 and returned unchanged.
    * ``bytes``: encoded directly.
    * anything else (e.g. a PIL image): saved as PNG into memory, then encoded.
    """
    if isinstance(img_path, str):
        if not os.path.isfile(img_path):
            return img_path  # already base64
        with open(img_path, 'rb') as f:
            raw = f.read()
    elif isinstance(img_path, bytes):
        raw = img_path
    else:
        buffer = BytesIO()
        img_path.save(buffer, format='png')
        raw = buffer.getvalue()
    return base64.b64encode(raw).decode('utf-8')
def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp') -> str:
    """Materialize a base64-encoded image as a PNG file under ``tmp_dir``.

    Non-string inputs (e.g. PIL images) are first encoded with ``_to_base64``.
    Strings naming an existing file, or starting with ``'http'``, are returned
    unchanged. Otherwise the file name is the SHA-256 of the base64 text, so
    identical payloads map to the same file. That file is only decoded and
    written when it does not exist yet.

    Args:
        img_base64: Base64 string, local path, URL, or image object.
        tmp_dir: Directory for decoded files; created on demand.

    Returns:
        The path of the decoded PNG, or the unchanged path/URL.
    """
    if not isinstance(img_base64, str):  # PIL.Image.Image
        img_base64 = _to_base64(img_base64)
    if os.path.isfile(img_base64) or img_base64.startswith('http'):
        return img_base64
    sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
    img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
    if not os.path.exists(img_path):
        # Decode only on a cache miss. A hit means this exact payload was already written.
        from PIL import Image
        if tmp_dir:
            # Callers other than decode_base64 may not have created the directory.
            os.makedirs(tmp_dir, exist_ok=True)
        image = Image.open(BytesIO(base64.b64decode(img_base64)))
        image.save(img_path)
    return img_path
def decode_base64(*,
                  messages: Optional[Messages] = None,
                  prompt: Optional[str] = None,
                  images: Optional[List[str]] = None,
                  tmp_dir: str = 'tmp') -> Dict[str, Any]:
    """Replace base64 media with local file paths under ``tmp_dir``.

    Only the arguments that are not ``None`` are processed and returned:

    * ``messages``: each message is deep-copied. String ``content`` goes through
      ``_decode_prompt``. Any other content (e.g. ``None`` or a list) is
      copied unchanged.
    * ``prompt``: passed through ``_decode_prompt``.
    * ``images``: each entry passed through ``_from_base64``.

    Returns:
        Dict with the keys ``'messages'``, ``'prompt'`` and/or ``'images'``.
    """
    # base64 -> local_path
    os.makedirs(tmp_dir, exist_ok=True)
    res = {}
    if messages is not None:
        res_messages = []
        for m in messages:
            m_new = deepcopy(m)
            content = m_new.get('content')
            # _decode_prompt needs a str; the Messages alias also allows a list of
            # dicts, and some messages carry no text content at all.
            if isinstance(content, str):
                m_new['content'] = _decode_prompt(content, tmp_dir)
            res_messages.append(m_new)
        res['messages'] = res_messages
    if prompt is not None:
        res['prompt'] = _decode_prompt(prompt, tmp_dir)
    if images is not None:
        res['images'] = [_from_base64(image, tmp_dir) for image in images]
    return res
def to_device(inputs: Any, device: torch.device) -> Any:
    """Recursively move ``inputs`` to ``device``.

    Any object with a callable ``to`` attribute is moved with
    ``obj.to(device=device)``. Mappings are rebuilt as plain ``dict`` objects.
    Non-string, non-bytes sequences are rebuilt as ``list`` objects, so tuples
    come back as lists. Both recurse into their elements. Strings, bytes-like
    values and everything else are returned unchanged.
    """
    if callable(getattr(inputs, 'to', None)):
        return inputs.to(device=device)
    if isinstance(inputs, Mapping):
        return {k: to_device(v, device) for k, v in inputs.items()}
    # bytes/bytearray are Sequences too; without excluding them they would be
    # exploded into a list of ints.
    if isinstance(inputs, Sequence) and not isinstance(inputs, (str, bytes, bytearray)):
        return [to_device(b, device) for b in inputs]
    return inputs
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    """Binary-search the largest x in [lo, hi] for which ``cond(x)`` holds.

    The result is only meaningful when ``cond`` is monotone (true up to some
    point, false after it). ``lo`` itself is never evaluated; it is returned
    when no larger candidate passes.
    """
    left, right = lo, hi
    while right > left:
        # Round the midpoint up so the interval always shrinks.
        probe = left + (right - left + 1) // 2
        if not cond(probe):
            right = probe - 1
        else:
            left = probe
    return left
def fetch_one(element: Union[Tuple, List, Set, Dict, Any], type: Type = None) -> Any:
    """Return the first truthy leaf found in a (possibly nested) container.

    Tuples, sets and lists are scanned in iteration order, and dicts by their
    values, recursing into nested containers. A non-container ``element`` is
    returned as-is; the ``type`` filter applies only to items found inside a
    container.

    Args:
        element: Value or nested container to search.
        type: If given, only truthy items that are instances of this type are
            accepted, at every nesting level. (The parameter name shadows the
            builtin but is kept for API compatibility.)

    Returns:
        The first matching truthy item, or ``None`` if a container holds none.
    """
    if isinstance(element, (tuple, set, list)):
        for ele in element:
            # Propagate the filter so nested containers honour it as well.
            out = fetch_one(ele, type)
            if out and (type is None or isinstance(out, type)):
                return out
        return None
    if isinstance(element, dict):
        return fetch_one(list(element.values()), type)
    return element
def _build_transform(input_size):
    """Build the InternVL image preprocessing pipeline.

    Steps: convert to RGB when needed, resize to ``input_size`` x ``input_size``
    with bicubic interpolation, convert to a tensor, then normalize with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    import torchvision.transforms as T
    from torchvision.transforms.functional import InterpolationMode

    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the candidate ratio ``r`` whose ``r[0] / r[1]`` is closest to ``aspect_ratio``.

    Candidates are scanned in order, and a strictly smaller difference wins. On
    an exact tie, a later candidate replaces the current choice only if the
    source area ``width * height`` exceeds half of
    ``image_size * image_size * r[0] * r[1]``. Returns ``(1, 1)`` when
    ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    source_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        if gap == smallest_gap and source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Tile ``image`` into ``image_size`` squares (InternVL dynamic resolution).

    Every grid ``(i, j)`` with ``min_num <= i * j <= max_num`` is a candidate.
    The one whose ``i / j`` best matches the image's aspect ratio is chosen by
    ``_find_closest_aspect_ratio``. The image is resized to ``i * image_size``
    by ``j * image_size`` and cut into ``i * j`` tiles in row-major order. With
    ``use_thumbnail`` and more than one tile, a copy of the whole image resized
    to one tile is appended.

    Returns:
        List of tile images (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Candidate grids: de-duplicated via a set, then ordered by tile count.
    # Built with the same loop order as before so equal-count ties keep their order.
    candidate_grids = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1)
                          for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = _find_closest_aspect_ratio(aspect_ratio, candidate_grids, orig_width, orig_height, image_size)
    cols, rows = grid[0], grid[1]
    resized_img = image.resize((image_size * cols, image_size * rows))

    tiles = []
    for index in range(cols * rows):
        row, col = divmod(index, cols)
        left, top = col * image_size, row * image_size
        tiles.append(resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == cols * rows

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


# <<< internvl
def rescale_image(img: 'PIL.Image.Image', rescale_image: int = -1) -> 'PIL.Image.Image':
    """Downscale ``img`` so its pixel count is at most about ``rescale_image``.

    The aspect ratio is preserved. When the budget is ``None`` or ``<= 0``, or
    the image is already within it, ``img`` is returned unchanged.

    Args:
        img: Image exposing ``width``/``height`` that ``torchvision`` ``Resize`` accepts.
        rescale_image: Maximum number of pixels; ``None`` or ``<= 0`` disables rescaling.

    Returns:
        ``img`` itself or a resized copy.
    """
    width = img.width
    height = img.height
    if rescale_image is None or rescale_image <= 0 or width * height <= rescale_image:
        return img
    import torchvision.transforms as T
    ratio = width / height
    height_scaled = math.sqrt(rescale_image / ratio)
    width_scaled = height_scaled * ratio
    # torchvision's Resize takes (height, width). Clamp to 1 so extreme aspect
    # ratios never request a zero-sized edge.
    return T.Resize((max(1, int(height_scaled)), max(1, int(width_scaled))))(img)
# Generic pass-through type used by load_file and load_batch.
_T = TypeVar('_T')


def load_file(path: Union[str, _T]) -> Union[BytesIO, _T]:
    """Resolve a media reference into an in-memory byte stream.

    A string is stripped and then interpreted, in order, as:

    1. an ``http``-prefixed URL, downloaded with ``requests``. The timeout comes
       from the ``TIMEOUT`` env var (default 60); a value ``<= 0`` disables it.
    2. an existing local path, which is read fully.
    3. base64 data, optionally wrapped as ``data:<mime>;base64,<payload>``.

    Non-string inputs are returned unchanged.

    Raises:
        ValueError: The string is not decodable base64.
        requests.HTTPError: The download returned an HTTP error status.
    """
    if not isinstance(path, str):
        return path
    path = path.strip()
    if path.startswith('http'):
        request_kwargs = {}
        timeout = float(os.getenv('TIMEOUT', '60'))
        if timeout > 0:
            request_kwargs['timeout'] = timeout
        response = requests.get(path, **request_kwargs)
        # Fail loudly on 4xx/5xx instead of passing an error page on as media bytes.
        response.raise_for_status()
        return BytesIO(response.content)
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return BytesIO(f.read())
    # base64_str, possibly in data-URI form.
    import binascii
    data = path
    if data.startswith('data:'):
        header, sep, payload = data.partition(',')
        if sep and header.endswith(';base64'):
            data = payload
    try:
        return BytesIO(base64.b64decode(data))
    except (ValueError, binascii.Error) as error:
        if len(path) < 200:
            raise ValueError(f'invalid image: "{path}"') from error
        raise ValueError(f'invalid image: {error}') from error
def load_file_decorator(func):
    """Wrap ``func`` so its first positional argument goes through ``load_file`` first.

    The remaining arguments are forwarded untouched. ``functools.wraps`` keeps
    the wrapped function's name, docstring and ``__wrapped__`` for
    introspection and debugging.
    """
    from functools import wraps

    @wraps(func)
    def new_func(path, *args, **kwargs):
        path = load_file(path)
        res = func(path, *args, **kwargs)
        return res

    return new_func
@load_file_decorator
def load_image(image: Union['PIL.Image.Image', BytesIO]) -> 'PIL.Image.Image':
    """Return ``image`` as an RGB PIL image.

    The decorator resolves string references via ``load_file``. A resulting
    ``BytesIO`` is opened with PIL, and any non-RGB image is converted to RGB.
    """
    from PIL import Image
    pil_image = Image.open(image) if isinstance(image, BytesIO) else image
    return pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')
def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
               load_func: Callable[[Any], _T] = load_image) -> List[_T]:
    """Apply ``load_func`` to every non-``None`` entry of ``path_list``.

    ``None`` entries are skipped, so the result can be shorter than the input.
    """
    assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
    return [load_func(path) for path in path_list if path is not None]
def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Pick ``num_segments`` frame indices, one near the middle of each equal segment.

    ``bound`` is an optional ``(start, end)`` window multiplied by ``fps`` (so
    presumably in seconds). Without it, a very wide default window is used, so
    the range reduces to ``[first_idx, max_frame]``. Returns a NumPy array of
    indices.
    """
    start, end = (bound[0], bound[1]) if bound else (-100000, 100000)
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    half_seg = seg_size / 2
    return np.array([int(start_idx + half_seg + np.round(seg_size * k)) for k in range(num_segments)])
def transform_image(image, input_size=448, max_num=12):
    """Tile ``image`` and stack the transformed tiles into a single tensor.

    Tiles come from ``_dynamic_preprocess`` (thumbnail enabled, at most
    ``max_num`` grid tiles). Each tile goes through ``_build_transform``, and
    the outputs are stacked along a new leading dimension.
    """
    pipeline = _build_transform(input_size=input_size)
    tiles = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([pipeline(tile) for tile in tiles])
@load_file_decorator
def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
    """Decode ``num_segments`` evenly spaced frames as RGB PIL images (InternVL).

    Frame positions come from ``_get_index`` over the whole clip, or over the
    optional ``bound`` window.
    """
    from decord import VideoReader, cpu
    from PIL import Image
    reader = VideoReader(video_io, ctx=cpu(0), num_threads=1)
    last_frame = len(reader) - 1
    avg_fps = float(reader.get_avg_fps())
    indices = _get_index(bound, avg_fps, last_frame, first_idx=0, num_segments=num_segments)
    return [Image.fromarray(reader[i].asnumpy()).convert('RGB') for i in indices]
def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
    """Draw ``bbox`` as a red rectangle on the image at ``img_dir`` and save it.

    The box (of kind ``bbox_type``, e.g. ``'norm_1000'``) is first passed
    through ``Template.normalize_bbox`` with mode ``'real'``. Presumably this
    yields pixel coordinates; confirm in Template.

    Args:
        img_dir: Path of the source image.
        bbox: Box coordinates as given by the caller.
        bbox_type: Coordinate convention of ``bbox``.
        output_file: Where the annotated image is written.
    """
    from PIL import Image, ImageDraw
    from swift.llm.template.template import Template
    # Context manager so the source file handle is released once we are done.
    with Image.open(img_dir) as image:
        objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
        Template.normalize_bbox(objects, [image], 'real')
        bbox = objects[0]['bbox']
        draw = ImageDraw.Draw(image)
        draw.rectangle(bbox, outline='red', width=2)
        image.save(output_file)
@load_file_decorator
def load_video_cogvlm2(video_io: BytesIO) -> 'torch.Tensor':
    """Sample 24 frames uniformly from the first 60 s of a video (CogVLM2).

    decord returns frame batches as (T, H, W, C); the result is permuted with
    ``(3, 0, 1, 2)``, i.e. to (C, T, H, W).

    Returns:
        A ``torch.Tensor`` of the sampled frames.
    """
    from decord import cpu, VideoReader
    clip_start_sec = 0
    clip_end_sec = 60
    num_frames = 24
    decord_vr = VideoReader(video_io, ctx=cpu(0))
    duration = len(decord_vr)  # duration in terms of frames
    avg_fps = decord_vr.get_avg_fps()
    start_frame = int(clip_start_sec * avg_fps)
    end_frame = min(duration, int(clip_end_sec * avg_fps))
    frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    # Deliberately NOT calling decord.bridge.set_bridge('torch'): that switches
    # decord's output type for later decord calls too and breaks the `.asnumpy()`
    # calls in the other video loaders. Convert the native batch explicitly instead
    # (and tolerate a torch bridge that somebody else may have enabled).
    batch = decord_vr.get_batch(frame_id_list)
    video_data = torch.from_numpy(batch.asnumpy()) if hasattr(batch, 'asnumpy') else batch
    return video_data.permute(3, 0, 1, 2)
@load_file_decorator
def load_video_llava(video_io: BytesIO) -> np.ndarray:
    """Decode 8 evenly spaced frames with PyAV and stack them as RGB arrays.

    Indices are ``np.arange(0, total_frames, total_frames / 8)`` truncated to
    int, taken from the stream's reported frame count. Each selected frame is
    converted with ``to_ndarray(format='rgb24')`` and the frames are stacked
    along a new leading axis.
    """
    import av
    # Context manager so the container (and its demuxer/decoder) is closed.
    with av.open(video_io) as container:
        total_frames = container.streams.video[0].frames
        indices = np.arange(0, total_frames, total_frames / 8).astype(int)
        wanted = set(indices.tolist())
        frames = []
        container.seek(0)
        start_index = indices[0]
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i >= start_index and i in wanted:
                # Convert while the container is still open.
                frames.append(frame.to_ndarray(format='rgb24'))
    return np.stack(frames)
@load_file_decorator
def load_video_minicpmv_mplug_owl3(video_io: BytesIO, max_num_frames):
    """Sample about one frame per second, capped at ``max_num_frames``.

    Frames are taken every ``round(avg_fps)`` frames (at least 1). If that
    yields more than ``max_num_frames``, they are thinned uniformly.

    Returns:
        List of PIL images.
    """
    from PIL import Image
    from decord import VideoReader, cpu  # pip install decord

    def uniform_sample(_l, _n):
        # Take the middle element of each of `_n` equal-width chunks of `_l`.
        gap = len(_l) / _n
        idxs = [int(i * gap + gap / 2) for i in range(_n)]
        return [_l[i] for i in idxs]

    vr = VideoReader(video_io, ctx=cpu(0))
    # Step of ~1 second expressed in frames. Clamp to 1: for average fps < 0.5
    # round() gives 0, and range() rejects a zero step.
    sample_fps = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), sample_fps))
    if len(frame_idx) > max_num_frames:
        frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    return [Image.fromarray(v.astype('uint8')) for v in frames]
@load_file_decorator
def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
    """Decode audio with librosa at ``sampling_rate`` and return only the waveform."""
    import librosa
    waveform, _ = librosa.load(audio_io, sr=sampling_rate)
    return waveform
def load_video_qwen2(video_path: str):
    """Read a video, then sample and resize its frames following the Qwen2-VL rules.

    Tunables are read with ``get_env_args``: ``nframes``, ``fps``,
    ``size_factor``, ``min_frames``, ``max_frames``, ``min_pixels``,
    ``total_pixels``, ``max_pixels``, ``resized_height`` and ``resized_width``.
    Defaults are the ``qwen_vl_utils.vision_process`` constants.

    Frame count: an explicit ``nframes`` is rounded to a multiple of
    ``size_factor``. Otherwise it is derived from the clip duration times
    ``fps`` (default ``FPS``) and clamped to ``[min_frames, max_frames]``.

    Returns:
        Float tensor of the sampled frames in TCHW layout, resized with bicubic
        interpolation.

    Raises:
        ValueError: The frame count falls outside ``[size_factor, total frames]``.
    """
    from swift.llm.template.template import get_env_args
    import torchvision
    from torchvision import io, transforms
    from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
                                              VIDEO_MIN_PIXELS, VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS, smart_resize,
                                              ceil_by_factor, floor_by_factor)
    from torchvision.transforms import InterpolationMode

    # On torchvision >= 0.19 the reference is resolved into a stream first
    # (presumably because read_video accepts file-like objects from then on).
    if version.parse(torchvision.__version__) >= version.parse('0.19'):
        video_path = load_file(video_path)
    video, _, info = io.read_video(
        video_path,
        pts_unit='sec',
        output_format='TCHW',
    )
    nframes = get_env_args('nframes', int, None)
    fps = get_env_args('fps', int, None)
    size_factor = get_env_args('size_factor', int, FRAME_FACTOR)
    assert not (fps and nframes), 'Only accept either `fps` or `nframes`'
    if nframes is not None:
        nframes = round_by_factor(nframes, size_factor)
    else:
        if fps is None:
            # Fall back to the default only when the user did not set `fps`;
            # overwriting it unconditionally silently ignored the setting.
            fps = FPS
        nframes = video.size(0) / info['video_fps'] * fps
        nframes = round_by_factor(nframes, size_factor)
        min_frames = get_env_args('min_frames', int, FPS_MIN_FRAMES)
        max_frames = get_env_args('max_frames', int, FPS_MAX_FRAMES)
        if nframes < min_frames:
            nframes = ceil_by_factor(min_frames, size_factor)
        if nframes > max_frames:
            nframes = floor_by_factor(max_frames, size_factor)

    if not (size_factor <= nframes and nframes <= video.size(0)):
        raise ValueError(f'nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.')

    # Evenly spaced frame indices across the whole clip.
    idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
    height, width = video.shape[2:]
    video = video[idx]

    min_pixels = get_env_args('min_pixels', int, VIDEO_MIN_PIXELS)
    total_pixels = get_env_args('total_pixels', int, VIDEO_TOTAL_PIXELS)
    max_pixels = get_env_args('max_pixels', int, None)
    if max_pixels is None:
        # Derive a per-frame budget from the total pixel budget.
        max_pixels = VIDEO_MAX_PIXELS
        max_pixels = max(min(max_pixels, total_pixels / nframes * size_factor), min_pixels * 1.05)
    # resize
    resized_height = get_env_args('resized_height', int, None)
    resized_width = get_env_args('resized_width', int, None)
    if resized_height and resized_width:
        resized_height, resized_width = smart_resize(
            resized_height,
            resized_width,
            factor=size_factor,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video
if __name__ == '__main__':
    # Manual check: render a 'norm_1000' box onto man.jpg and write the result to man_bbox.jpg.
    draw_plot(img_dir='man.jpg', bbox=[354, 462, 580, 738], bbox_type='norm_1000', output_file='man_bbox.jpg')
|