pipelines.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # TODO:
  15. # 1. Reuse `httpx` client.
  16. # 2. Use `contextvars` to manage MCP context objects.
  17. # 3. Implement structured logging, log stack traces, and log operation timing.
  18. # 4. Report progress for long-running operations.
  19. import abc
  20. import asyncio
  21. import base64
  22. import io
  23. import json
  24. import re
  25. from pathlib import PurePath
  26. from queue import Queue
  27. from threading import Thread
  28. from typing import Any, Callable, Dict, List, NoReturn, Optional, Type, Union
  29. from urllib.parse import urlparse
  30. import httpx
  31. import numpy as np
  32. import puremagic
  33. from fastmcp import Context, FastMCP
  34. from mcp.types import ImageContent, TextContent
  35. from PIL import Image as PILImage
  36. from typing_extensions import Literal, Self, assert_never
  37. try:
  38. from paddleocr import PaddleOCR, PaddleOCRVL, PPStructureV3
  39. LOCAL_OCR_AVAILABLE = True
  40. except ImportError:
  41. LOCAL_OCR_AVAILABLE = False
# Output formats a tool caller may request: "simple" or "detailed".
OutputMode = Literal["simple", "detailed"]
  43. def _is_file_path(s: str) -> bool:
  44. try:
  45. PurePath(s)
  46. return True
  47. except Exception:
  48. return False
  49. def _is_base64(s: str) -> bool:
  50. pattern = r"^[A-Za-z0-9+/]+={0,2}$"
  51. return bool(re.fullmatch(pattern, s))
  52. def _is_url(s: str) -> bool:
  53. if not (s.startswith("http://") or s.startswith("https://")):
  54. return False
  55. result = urlparse(s)
  56. return all([result.scheme, result.netloc]) and result.scheme in ("http", "https")
  57. def _infer_file_type_from_bytes(data: bytes) -> Optional[str]:
  58. mime = puremagic.from_string(data, mime=True)
  59. if mime.startswith("image/"):
  60. return "image"
  61. elif mime == "application/pdf":
  62. return "pdf"
  63. return None
  64. def get_str_with_max_len(obj: object, max_len: int) -> str:
  65. s = str(obj)
  66. if len(s) > max_len:
  67. return s[:max_len] + "..."
  68. else:
  69. return s
  70. class _EngineWrapper:
  71. def __init__(self, engine: Any) -> None:
  72. self._engine = engine
  73. self._queue: Queue = Queue()
  74. self._closed = False
  75. self._loop = asyncio.get_running_loop()
  76. self._thread = Thread(target=self._worker, daemon=False)
  77. self._thread.start()
  78. @property
  79. def engine(self) -> Any:
  80. return self._engine
  81. async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
  82. if self._closed:
  83. raise RuntimeError("Engine wrapper has already been closed")
  84. fut = self._loop.create_future()
  85. self._queue.put((func, args, kwargs, fut))
  86. return await fut
  87. async def close(self) -> None:
  88. if not self._closed:
  89. self._queue.put(None)
  90. await self._loop.run_in_executor(None, self._thread.join)
  91. self._closed = True
  92. def _worker(self) -> None:
  93. while not self._closed:
  94. item = self._queue.get()
  95. if item is None:
  96. break
  97. func, args, kwargs, fut = item
  98. try:
  99. result = func(*args, **kwargs)
  100. self._loop.call_soon_threadsafe(fut.set_result, result)
  101. except Exception as e:
  102. self._loop.call_soon_threadsafe(fut.set_exception, e)
  103. finally:
  104. self._queue.task_done()
  105. class PipelineHandler(abc.ABC):
  106. """Abstract base class for pipeline handlers."""
  107. def __init__(
  108. self,
  109. pipeline: str,
  110. ppocr_source: str,
  111. pipeline_config: Optional[str],
  112. device: Optional[str],
  113. server_url: Optional[str],
  114. aistudio_access_token: Optional[str],
  115. qianfan_api_key: Optional[str],
  116. timeout: Optional[int],
  117. ) -> None:
  118. """Initialize the pipeline handler.
  119. Args:
  120. pipeline: Pipeline name.
  121. ppocr_source: Source of PaddleOCR functionality.
  122. pipeline_config: Path to pipeline configuration.
  123. device: Device to run inference on.
  124. server_url: Base URL for service mode.
  125. aistudio_access_token: AI Studio access token.
  126. qianfan_api_key: Qianfan API key.
  127. timeout: Read timeout in seconds for HTTP requests.
  128. """
  129. self._pipeline = pipeline
  130. if ppocr_source == "local":
  131. self._mode = "local"
  132. elif ppocr_source in ("aistudio", "qianfan", "self_hosted"):
  133. self._mode = "service"
  134. else:
  135. raise ValueError(f"Unknown PaddleOCR source {repr(ppocr_source)}")
  136. self._ppocr_source = ppocr_source
  137. self._pipeline_config = pipeline_config
  138. self._device = device
  139. self._server_url = server_url
  140. self._aistudio_access_token = aistudio_access_token
  141. self._qianfan_api_key = qianfan_api_key
  142. self._timeout = timeout or 60
  143. if self._mode == "local":
  144. if not LOCAL_OCR_AVAILABLE:
  145. raise RuntimeError("PaddleOCR is not locally available")
  146. try:
  147. self._engine = self._create_local_engine()
  148. except Exception as e:
  149. raise RuntimeError(
  150. f"Failed to create PaddleOCR engine: {str(e)}"
  151. ) from e
  152. self._status: Literal["initialized", "started", "stopped"] = "initialized"
  153. async def start(self) -> None:
  154. if self._status == "initialized":
  155. if self._mode == "local":
  156. self._engine_wrapper = _EngineWrapper(self._engine)
  157. self._status = "started"
  158. elif self._status == "started":
  159. pass
  160. elif self._status == "stopped":
  161. raise RuntimeError("Pipeline handler has already been stopped")
  162. else:
  163. assert_never(self._status)
  164. async def stop(self) -> None:
  165. if self._status == "initialized":
  166. raise RuntimeError("Pipeline handler has not been started")
  167. elif self._status == "started":
  168. if self._mode == "local":
  169. await self._engine_wrapper.close()
  170. self._status = "stopped"
  171. elif self._status == "stopped":
  172. pass
  173. else:
  174. assert_never(self._status)
  175. async def __aenter__(self) -> Self:
  176. await self.start()
  177. return self
  178. async def __aexit__(
  179. self,
  180. exc_type: Any,
  181. exc_val: Any,
  182. exc_tb: Any,
  183. ) -> None:
  184. await self.stop()
  185. @abc.abstractmethod
  186. def register_tools(self, mcp: FastMCP) -> None:
  187. """Register tools with the MCP server.
  188. Args:
  189. mcp: The `FastMCP` instance.
  190. """
  191. raise NotImplementedError
  192. @abc.abstractmethod
  193. def _create_local_engine(self) -> Any:
  194. """Create the local OCR engine.
  195. Returns:
  196. The OCR engine instance.
  197. """
  198. raise NotImplementedError
  199. @abc.abstractmethod
  200. def _get_service_endpoint(self) -> str:
  201. """Get the service endpoint.
  202. Returns:
  203. Service endpoint path.
  204. """
  205. raise NotImplementedError
  206. @abc.abstractmethod
  207. def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  208. """Transform keyword arguments for local execution.
  209. Args:
  210. kwargs: Keyword arguments.
  211. Returns:
  212. Transformed keyword arguments.
  213. """
  214. raise NotImplementedError
  215. @abc.abstractmethod
  216. def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  217. """Transform keyword arguments for service execution.
  218. Args:
  219. kwargs: Keyword arguments.
  220. Returns:
  221. Transformed keyword arguments.
  222. """
  223. raise NotImplementedError
  224. @abc.abstractmethod
  225. async def _parse_local_result(
  226. self, local_result: Dict, ctx: Context
  227. ) -> Dict[str, Any]:
  228. """Parse raw result from local engine into a unified format.
  229. Args:
  230. local_result: Raw result from local engine.
  231. ctx: MCP context.
  232. Returns:
  233. Parsed result in unified format.
  234. """
  235. raise NotImplementedError
  236. @abc.abstractmethod
  237. async def _parse_service_result(
  238. self, service_result: Dict[str, Any], ctx: Context
  239. ) -> Dict[str, Any]:
  240. """Parse raw result from the service into a unified format.
  241. Args:
  242. service_result: Raw result from the service.
  243. ctx: MCP context.
  244. Returns:
  245. Parsed result in unified format.
  246. """
  247. raise NotImplementedError
  248. @abc.abstractmethod
  249. async def _log_completion_stats(self, result: Dict[str, Any], ctx: Context) -> None:
  250. """Log statistics after processing completion.
  251. Args:
  252. result: Processing result.
  253. ctx: MCP context.
  254. """
  255. raise NotImplementedError
  256. @abc.abstractmethod
  257. async def _format_output(
  258. self,
  259. result: Dict[str, Any],
  260. detailed: bool,
  261. ctx: Context,
  262. **kwargs: Any,
  263. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  264. """Format output into simple or detailed format.
  265. Args:
  266. result: Processing result.
  267. detailed: Whether to use detailed format.
  268. ctx: MCP context.
  269. **kwargs: Additional arguments.
  270. Returns:
  271. Formatted output in requested format.
  272. """
  273. raise NotImplementedError
  274. async def _predict_with_local_engine(
  275. self, processed_input: Union[str, np.ndarray], ctx: Context, **kwargs: Any
  276. ) -> Dict:
  277. if not hasattr(self, "_engine_wrapper"):
  278. raise RuntimeError("Engine wrapper has not been initialized")
  279. return await self._engine_wrapper.call(
  280. self._engine_wrapper.engine.predict, processed_input, **kwargs
  281. )
  282. class SimpleInferencePipelineHandler(PipelineHandler):
  283. """Base class for simple inference pipeline handlers."""
  284. async def process(
  285. self,
  286. input_data: str,
  287. output_mode: OutputMode,
  288. ctx: Context,
  289. file_type: Optional[str] = None,
  290. infer_kwargs: Optional[Dict[str, Any]] = None,
  291. format_kwargs: Optional[Dict[str, Any]] = None,
  292. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  293. """Process input data through the pipeline.
  294. Args:
  295. input_data: Input data (file path, URL, or Base64).
  296. output_mode: Output mode ("simple" or "detailed").
  297. ctx: MCP context.
  298. file_type: File type for URLs ("image", "pdf", or None for auto-detection).
  299. infer_kwargs: Additional arguments for performing pipeline inference.
  300. format_kwargs: Additional arguments for formatting the output.
  301. Returns:
  302. Processed result in the requested output format.
  303. """
  304. infer_kwargs = infer_kwargs or {}
  305. format_kwargs = format_kwargs or {}
  306. try:
  307. await ctx.info(
  308. f"Starting {self._pipeline} processing (source: {self._ppocr_source})"
  309. )
  310. if self._mode == "local":
  311. processed_input = self._process_input_for_local(input_data, file_type)
  312. infer_kwargs = self._transform_local_kwargs(infer_kwargs)
  313. raw_result = await self._predict_with_local_engine(
  314. processed_input, ctx, **infer_kwargs
  315. )
  316. result = await self._parse_local_result(raw_result, ctx)
  317. else:
  318. processed_input, inferred_file_type = self._process_input_for_service(
  319. input_data, file_type
  320. )
  321. infer_kwargs = self._transform_service_kwargs(infer_kwargs)
  322. raw_result = await self._call_service(
  323. processed_input, inferred_file_type, ctx, **infer_kwargs
  324. )
  325. result = await self._parse_service_result(raw_result, ctx)
  326. await self._log_completion_stats(result, ctx)
  327. return await self._format_output(
  328. result, output_mode == "detailed", ctx, **format_kwargs
  329. )
  330. except Exception as e:
  331. await ctx.error(f"{self._pipeline} processing failed: {str(e)}")
  332. self._handle_error(e, output_mode)
  333. def _process_input_for_local(
  334. self, input_data: str, file_type: Optional[str]
  335. ) -> Union[str, np.ndarray]:
  336. # TODO: Use `file_type` to handle more cases.
  337. if _is_base64(input_data):
  338. if input_data.startswith("data:"):
  339. base64_data = input_data.split(",", 1)[1]
  340. else:
  341. base64_data = input_data
  342. try:
  343. image_bytes = base64.b64decode(base64_data)
  344. file_type = _infer_file_type_from_bytes(image_bytes)
  345. if file_type != "image":
  346. raise ValueError("Currently, only images can be passed via Base64.")
  347. image_pil = PILImage.open(io.BytesIO(image_bytes))
  348. image_arr = np.array(image_pil.convert("RGB"))
  349. return np.ascontiguousarray(image_arr[..., ::-1])
  350. except Exception as e:
  351. raise ValueError(f"Failed to decode Base64 image: {str(e)}") from e
  352. elif _is_file_path(input_data) or _is_url(input_data):
  353. return input_data
  354. else:
  355. raise ValueError("Invalid input data format")
  356. def _process_input_for_service(
  357. self, input_data: str, file_type: Optional[str]
  358. ) -> tuple[str, Optional[str]]:
  359. if _is_url(input_data):
  360. norm_ft = None
  361. if isinstance(file_type, str):
  362. if file_type.lower() in ("None", "none", "null", "unknown", ""):
  363. norm_ft = None
  364. else:
  365. norm_ft = file_type.lower()
  366. return input_data, norm_ft
  367. elif _is_base64(input_data):
  368. try:
  369. if input_data.startswith("data:"):
  370. base64_data = input_data.split(",", 1)[1]
  371. else:
  372. base64_data = input_data
  373. bytes_ = base64.b64decode(base64_data)
  374. file_type_str = _infer_file_type_from_bytes(bytes_)
  375. if file_type_str is None:
  376. raise ValueError(
  377. "Unsupported file type in Base64 data. "
  378. "Only images (JPEG, PNG, etc.) and PDF documents are supported."
  379. )
  380. return input_data, file_type_str
  381. except Exception as e:
  382. raise ValueError(f"Failed to decode Base64 data: {str(e)}") from e
  383. elif _is_file_path(input_data):
  384. try:
  385. with open(input_data, "rb") as f:
  386. bytes_ = f.read()
  387. input_data = base64.b64encode(bytes_).decode("ascii")
  388. file_type_str = _infer_file_type_from_bytes(bytes_)
  389. if file_type_str is None:
  390. raise ValueError(
  391. f"Unsupported file type for '{input_data}'. "
  392. "Only images (JPEG, PNG, etc.) and PDF documents are supported."
  393. )
  394. return input_data, file_type_str
  395. except Exception as e:
  396. raise ValueError(f"Failed to read file: {str(e)}") from e
  397. else:
  398. raise ValueError("Invalid input data format")
  399. async def _call_service(
  400. self,
  401. processed_input: str,
  402. file_type: Optional[str],
  403. ctx: Context,
  404. **kwargs: Any,
  405. ) -> Dict[str, Any]:
  406. if not self._server_url:
  407. raise RuntimeError("Server URL not configured")
  408. endpoint = self._get_service_endpoint()
  409. if endpoint:
  410. endpoint = "/" + endpoint
  411. url = f"{self._server_url.rstrip('/')}{endpoint}"
  412. payload = self._prepare_service_payload(processed_input, file_type, **kwargs)
  413. headers = {"Content-Type": "application/json"}
  414. if self._ppocr_source == "aistudio":
  415. if not self._aistudio_access_token:
  416. raise RuntimeError("Missing AI Studio access token")
  417. headers["Authorization"] = f"token {self._aistudio_access_token}"
  418. elif self._ppocr_source == "qianfan":
  419. if not self._qianfan_api_key:
  420. raise RuntimeError("Missing Qianfan API key")
  421. headers["Authorization"] = f"Bearer {self._qianfan_api_key}"
  422. try:
  423. timeout = httpx.Timeout(
  424. connect=30.0, read=self._timeout, write=30.0, pool=30.0
  425. )
  426. async with httpx.AsyncClient(timeout=timeout) as client:
  427. response = await client.post(url, json=payload, headers=headers)
  428. response.raise_for_status()
  429. return response.json()
  430. except httpx.HTTPError as e:
  431. raise RuntimeError(f"HTTP request failed: {type(e).__name__}: {str(e)}")
  432. except json.JSONDecodeError as e:
  433. raise RuntimeError(f"Invalid service response: {str(e)}")
  434. def _prepare_service_payload(
  435. self, processed_input: str, file_type: Optional[str], **kwargs: Any
  436. ) -> Dict[str, Any]:
  437. payload: Dict[str, Any] = {"file": processed_input, **kwargs}
  438. if file_type == "image":
  439. payload["fileType"] = 1
  440. elif file_type == "pdf":
  441. payload["fileType"] = 0
  442. else:
  443. payload["fileType"] = None
  444. return payload
  445. def _handle_error(self, exc: Exception, output_mode: OutputMode) -> NoReturn:
  446. raise exc
  447. class OCRHandler(SimpleInferencePipelineHandler):
  448. def register_tools(self, mcp: FastMCP) -> None:
  449. @mcp.tool("ocr")
  450. async def _ocr(
  451. input_data: str,
  452. output_mode: OutputMode = "simple",
  453. file_type: Optional[str] = None,
  454. *,
  455. ctx: Context,
  456. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  457. """Extracts text from images and PDFs. Accepts file path, URL, or Base64.
  458. Args:
  459. input_data: The file to process (file path, URL, or Base64 string).
  460. output_mode: The desired output format.
  461. - "simple": (Default) Clean, readable text suitable for most use cases.
  462. - "detailed": A JSON output including text, confidence, and precise bounding box coordinates. Only use this when coordinates are specifically required.
  463. file_type: File type. This parameter is REQUIRED when `input_data` is a URL and should be omitted for other types.
  464. - "image": For image files
  465. - "pdf": For PDF documents
  466. - None: For unknown file types
  467. """
  468. await ctx.info(
  469. f"--- OCR tool received `input_data`: {get_str_with_max_len(input_data, 50)} ---"
  470. )
  471. return await self.process(input_data, output_mode, ctx, file_type)
  472. def _create_local_engine(self) -> Any:
  473. return PaddleOCR(
  474. paddlex_config=self._pipeline_config,
  475. device=self._device,
  476. )
  477. def _get_service_endpoint(self) -> str:
  478. return "ocr"
  479. def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  480. return {
  481. "use_doc_unwarping": False,
  482. "use_doc_orientation_classify": False,
  483. }
  484. def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  485. return {
  486. "useDocUnwarping": False,
  487. "useDocOrientationClassify": False,
  488. }
  489. async def _parse_local_result(self, local_result: Dict, ctx: Context) -> Dict:
  490. clean_texts, confidences, text_lines = [], [], []
  491. for result in local_result:
  492. texts = result["rec_texts"]
  493. scores = result["rec_scores"]
  494. boxes = result["rec_boxes"]
  495. for i, text in enumerate(texts):
  496. if text and text.strip():
  497. conf = scores[i] if i < len(scores) else 0
  498. clean_texts.append(text.strip())
  499. confidences.append(conf)
  500. instance = {
  501. "text": text.strip(),
  502. "confidence": round(conf, 3),
  503. "bbox": boxes[i].tolist(),
  504. }
  505. text_lines.append(instance)
  506. return {
  507. "text": "\n".join(clean_texts),
  508. "confidence": sum(confidences) / len(confidences) if confidences else 0,
  509. "text_lines": text_lines,
  510. }
  511. async def _parse_service_result(self, service_result: Dict, ctx: Context) -> Dict:
  512. result_data = service_result.get("result", service_result)
  513. ocr_results = result_data.get("ocrResults")
  514. all_texts, all_confidences, text_lines = [], [], []
  515. for ocr_result in ocr_results:
  516. pruned = ocr_result["prunedResult"]
  517. texts = pruned["rec_texts"]
  518. scores = pruned["rec_scores"]
  519. boxes = pruned["rec_boxes"]
  520. for i, text in enumerate(texts):
  521. if text and text.strip():
  522. conf = scores[i] if i < len(scores) else 0
  523. all_texts.append(text.strip())
  524. all_confidences.append(conf)
  525. instance = {
  526. "text": text.strip(),
  527. "confidence": round(conf, 3),
  528. "bbox": boxes[i],
  529. }
  530. text_lines.append(instance)
  531. return {
  532. "text": "\n".join(all_texts),
  533. "confidence": (
  534. sum(all_confidences) / len(all_confidences) if all_confidences else 0
  535. ),
  536. "text_lines": text_lines,
  537. }
  538. async def _log_completion_stats(self, result: Dict, ctx: Context) -> None:
  539. text_length = len(result["text"])
  540. text_line_count = len(result["text_lines"])
  541. await ctx.info(
  542. f"OCR completed: {text_length} characters, {text_line_count} text lines"
  543. )
  544. async def _format_output(
  545. self,
  546. result: Dict,
  547. detailed: bool,
  548. ctx: Context,
  549. **kwargs: Any,
  550. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  551. if not result["text"].strip():
  552. return (
  553. "❌ No text detected"
  554. if not detailed
  555. else json.dumps({"error": "No text detected"}, ensure_ascii=False)
  556. )
  557. if detailed:
  558. return json.dumps(result, ensure_ascii=False, indent=2)
  559. else:
  560. confidence = result["confidence"]
  561. text_line_count = len(result["text_lines"])
  562. output = result["text"]
  563. if confidence > 0:
  564. output += f"\n\n📊 Confidence: {(confidence * 100):.1f}% | {text_line_count} text lines"
  565. return output
  566. class _LayoutParsingHandler(SimpleInferencePipelineHandler):
  567. def _get_service_endpoint(self) -> str:
  568. return "layout-parsing" if self._ppocr_source != "qianfan" else "paddleocr"
  569. def _transform_local_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  570. return {
  571. "use_doc_unwarping": False,
  572. "use_doc_orientation_classify": False,
  573. }
  574. def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  575. return {
  576. "useDocUnwarping": False,
  577. "useDocOrientationClassify": False,
  578. }
  579. async def _parse_local_result(self, local_result: Dict, ctx: Context) -> Dict:
  580. markdown_parts = []
  581. all_images_mapping = {}
  582. detailed_results = []
  583. for result in local_result:
  584. markdown = result.markdown
  585. text = markdown["markdown_texts"]
  586. markdown_parts.append(text)
  587. images = markdown["markdown_images"]
  588. processed_images = {}
  589. for img_key, img_data in images.items():
  590. with io.BytesIO() as buffer:
  591. img_data.save(buffer, format="JPEG")
  592. processed_images[img_key] = base64.b64encode(buffer.getvalue())
  593. all_images_mapping.update(processed_images)
  594. detailed_results.append(result)
  595. return {
  596. # TODO: Page concatenation can be done better via `pipeline.concatenate_markdown_pages`
  597. "markdown": "\n".join(markdown_parts),
  598. "pages": len(local_result),
  599. "images_mapping": all_images_mapping,
  600. "detailed_results": detailed_results,
  601. }
  602. async def _parse_service_result(self, service_result: Dict, ctx: Context) -> Dict:
  603. result_data = service_result.get("result", service_result)
  604. layout_results = result_data.get("layoutParsingResults")
  605. if not layout_results:
  606. return {
  607. "markdown": "",
  608. "pages": 0,
  609. "images_mapping": {},
  610. "detailed_results": [],
  611. }
  612. markdown_parts = []
  613. all_images_mapping = {}
  614. detailed_results = []
  615. for res in layout_results:
  616. markdown_parts.append(res["markdown"]["text"])
  617. images = res["markdown"]["images"]
  618. processed_images = {}
  619. for img_key, img_data in images.items():
  620. processed_images[img_key] = await self._process_image_data(
  621. img_data, ctx
  622. )
  623. all_images_mapping.update(processed_images)
  624. detailed_results.append(res["prunedResult"])
  625. return {
  626. "markdown": "\n".join(markdown_parts),
  627. "pages": len(layout_results),
  628. "images_mapping": all_images_mapping,
  629. "detailed_results": detailed_results,
  630. }
  631. async def _process_image_data(self, img_data: str, ctx: Context) -> str:
  632. if _is_url(img_data):
  633. try:
  634. timeout = httpx.Timeout(connect=30.0, read=30.0, write=30.0, pool=30.0)
  635. async with httpx.AsyncClient(timeout=timeout) as client:
  636. response = await client.get(img_data)
  637. response.raise_for_status()
  638. img_bytes = response.content
  639. return base64.b64encode(img_bytes).decode("ascii")
  640. except Exception as e:
  641. await ctx.error(
  642. f"Failed to download image from URL {img_data}: {str(e)}"
  643. )
  644. return img_data
  645. elif _is_base64(img_data):
  646. return img_data
  647. else:
  648. await ctx.error(
  649. f"Unknown image data format: {get_str_with_max_len(img_data, 50)}"
  650. )
  651. return img_data
  652. async def _log_completion_stats(self, result: Dict, ctx: Context) -> None:
  653. page_count = result["pages"]
  654. await ctx.info(f"Layout parsing completed: {page_count} pages")
  655. async def _format_output(
  656. self,
  657. result: Dict,
  658. detailed: bool,
  659. ctx: Context,
  660. **kwargs: Any,
  661. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  662. if not result["markdown"].strip():
  663. return (
  664. "❌ No document content detected"
  665. if not detailed
  666. else json.dumps({"error": "No content detected"}, ensure_ascii=False)
  667. )
  668. markdown_text = result["markdown"]
  669. images_mapping = result.get("images_mapping", {})
  670. if kwargs.get("return_images"):
  671. content_list = self._parse_markdown_with_images(
  672. markdown_text, images_mapping
  673. )
  674. else:
  675. content_list = [TextContent(type="text", text=markdown_text)]
  676. if detailed:
  677. if "detailed_results" in result and result["detailed_results"]:
  678. for detailed_result in result["detailed_results"]:
  679. content_list.append(
  680. TextContent(
  681. type="text",
  682. text=json.dumps(
  683. detailed_result,
  684. ensure_ascii=False,
  685. indent=2,
  686. default=str,
  687. ),
  688. )
  689. )
  690. return content_list
  691. def _parse_markdown_with_images(
  692. self, markdown_text: str, images_mapping: Dict[str, str]
  693. ) -> List[Union[TextContent, ImageContent]]:
  694. """Parse markdown text and return mixed list of text and images."""
  695. if not images_mapping:
  696. return [TextContent(type="text", text=markdown_text)]
  697. content_list = []
  698. img_pattern = r'<img[^>]+src="([^"]+)"[^>]*>'
  699. last_pos = 0
  700. for match in re.finditer(img_pattern, markdown_text):
  701. text_before = markdown_text[last_pos : match.start()]
  702. if text_before.strip():
  703. content_list.append(TextContent(type="text", text=text_before))
  704. img_src = match.group(1)
  705. if img_src in images_mapping:
  706. content_list.append(
  707. ImageContent(
  708. type="image",
  709. data=images_mapping[img_src],
  710. mimeType="image/jpeg",
  711. )
  712. )
  713. last_pos = match.end()
  714. remaining_text = markdown_text[last_pos:]
  715. if remaining_text.strip():
  716. content_list.append(TextContent(type="text", text=remaining_text))
  717. return content_list or [TextContent(type="text", text=markdown_text)]
  718. class PPStructureV3Handler(_LayoutParsingHandler):
  719. def register_tools(self, mcp: FastMCP) -> None:
  720. @mcp.tool("pp_structurev3")
  721. async def _pp_structurev3(
  722. input_data: str,
  723. output_mode: OutputMode = "simple",
  724. file_type: Optional[str] = None,
  725. return_images: bool = True,
  726. *,
  727. ctx: Context,
  728. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  729. """Extracts structured markdown from complex documents (images/PDFs), including tables, formulas, etc. Accepts file path, URL, or Base64.
  730. Args:
  731. input_data: The file to process (file path, URL, or Base64 string).
  732. output_mode: The desired output format.
  733. - "simple": (Default) Clean, readable markdown with embedded images. Best for most use cases.
  734. - "detailed": JSON data about document structure, plus markdown. Only use this when coordinates are specifically required.
  735. file_type: File type. This parameter is REQUIRED when `input_data` is a URL and should be omitted for other types.
  736. - "image": For image files
  737. - "pdf": For PDF documents
  738. - None: For unknown file types
  739. return_images: Whether to return the images extracted from the document.
  740. """
  741. return await self.process(
  742. input_data,
  743. output_mode,
  744. ctx,
  745. file_type,
  746. format_kwargs={"return_images": return_images},
  747. )
  748. def _create_local_engine(self) -> Any:
  749. return PPStructureV3(
  750. paddlex_config=self._pipeline_config,
  751. device=self._device,
  752. )
  753. def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  754. kwargs = super()._transform_service_kwargs(kwargs)
  755. if self._ppocr_source == "qianfan":
  756. kwargs["model"] = "pp-structurev3"
  757. return kwargs
  758. class PaddleOCRVLHandler(_LayoutParsingHandler):
  759. def register_tools(self, mcp: FastMCP) -> None:
  760. @mcp.tool("paddleocr_vl")
  761. async def _paddleocr_vl(
  762. input_data: str,
  763. output_mode: OutputMode = "simple",
  764. file_type: Optional[str] = None,
  765. return_images: bool = True,
  766. *,
  767. ctx: Context,
  768. ) -> Union[str, List[Union[TextContent, ImageContent]]]:
  769. """Extracts structured markdown from complex documents (images/PDFs) using a VLM-based approach. The extracted elements include tables, formulas, etc. Accepts file path, URL, or Base64.
  770. Args:
  771. input_data: The file to process (file path, URL, or Base64 string).
  772. output_mode: The desired output format.
  773. - "simple": (Default) Clean, readable markdown with embedded images. Best for most use cases.
  774. - "detailed": JSON data about document structure, plus markdown. Only use this when coordinates are specifically required.
  775. file_type: File type. This parameter is REQUIRED when `input_data` is a URL and should be omitted for other types.
  776. - "image": For image files
  777. - "pdf": For PDF documents
  778. - None: For unknown file types
  779. return_images: Whether to return the images extracted from the document.
  780. """
  781. return await self.process(
  782. input_data,
  783. output_mode,
  784. ctx,
  785. file_type,
  786. format_kwargs={"return_images": return_images},
  787. )
  788. def _create_local_engine(self) -> Any:
  789. return PaddleOCRVL(
  790. paddlex_config=self._pipeline_config,
  791. device=self._device,
  792. )
  793. def _transform_service_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  794. kwargs = super()._transform_service_kwargs(kwargs)
  795. if self._ppocr_source == "qianfan":
  796. kwargs["model"] = "paddleocr-vl-0.9b"
  797. return kwargs
# Registry of pipeline names and the handler class that serves each one;
# `create_pipeline_handler` looks pipelines up here.
_PIPELINE_HANDLERS: Dict[str, Type[PipelineHandler]] = {
    "OCR": OCRHandler,
    "PP-StructureV3": PPStructureV3Handler,
    "PaddleOCR-VL": PaddleOCRVLHandler,
}
  803. def create_pipeline_handler(
  804. pipeline: str, /, *args: Any, **kwargs: Any
  805. ) -> PipelineHandler:
  806. if pipeline in _PIPELINE_HANDLERS:
  807. cls = _PIPELINE_HANDLERS[pipeline]
  808. return cls(pipeline, *args, **kwargs)
  809. else:
  810. raise ValueError(f"Unknown pipeline {repr(pipeline)}")