cli.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import asyncio
  2. import os
  3. import signal
  4. import traceback
  5. from typing import Optional
  6. import typer
  7. from rich import print
  8. from ._cli_hacks import _async_prompt, _patch_anyio_open_process
  9. from .agent import Agent
  10. from .utils import _load_agent_config
  11. app = typer.Typer(
  12. rich_markup_mode="rich",
  13. help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
  14. )
  15. run_cli = typer.Typer(
  16. name="run",
  17. help="Run the Agent in the CLI",
  18. invoke_without_command=True,
  19. )
  20. app.add_typer(run_cli, name="run")
  21. async def run_agent(
  22. agent_path: Optional[str],
  23. ) -> None:
  24. """
  25. Tiny Agent loop.
  26. Args:
  27. agent_path (`str`, *optional*):
  28. Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.
  29. """
  30. _patch_anyio_open_process() # Hacky way to prevent stdio connections to be stopped by Ctrl+C
  31. config, prompt = _load_agent_config(agent_path)
  32. inputs = config.get("inputs", [])
  33. servers = config.get("servers", [])
  34. abort_event = asyncio.Event()
  35. exit_event = asyncio.Event()
  36. first_sigint = True
  37. loop = asyncio.get_running_loop()
  38. original_sigint_handler = signal.getsignal(signal.SIGINT)
  39. def _sigint_handler() -> None:
  40. nonlocal first_sigint
  41. if first_sigint:
  42. first_sigint = False
  43. abort_event.set()
  44. print("\n[red]Interrupted. Press Ctrl+C again to quit.[/red]", flush=True)
  45. return
  46. print("\n[red]Exiting...[/red]", flush=True)
  47. exit_event.set()
  48. try:
  49. sigint_registered_in_loop = False
  50. try:
  51. loop.add_signal_handler(signal.SIGINT, _sigint_handler)
  52. sigint_registered_in_loop = True
  53. except (AttributeError, NotImplementedError):
  54. # Windows (or any loop that doesn't support it) : fall back to sync
  55. signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
  56. # Handle inputs (i.e. env variables injection)
  57. resolved_inputs: dict[str, str] = {}
  58. if len(inputs) > 0:
  59. print(
  60. "[bold blue]Some initial inputs are required by the agent. "
  61. "Please provide a value or leave empty to load from env.[/bold blue]"
  62. )
  63. for input_item in inputs:
  64. input_id = input_item["id"]
  65. description = input_item["description"]
  66. env_special_value = f"${{input:{input_id}}}"
  67. # Check if the input is used by any server or as an apiKey
  68. input_usages = set()
  69. for server in servers:
  70. # Check stdio's "env" and http/sse's "headers" mappings
  71. env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
  72. for key, value in env_or_headers.items():
  73. if env_special_value in value:
  74. input_usages.add(key)
  75. raw_api_key = config.get("apiKey")
  76. if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
  77. input_usages.add("apiKey")
  78. if not input_usages:
  79. print(
  80. f"[yellow]Input '{input_id}' defined in config but not used by any server or as an API key."
  81. " Skipping.[/yellow]"
  82. )
  83. continue
  84. # Prompt user for input
  85. env_variable_key = input_id.replace("-", "_").upper()
  86. print(
  87. f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).",
  88. end=" ",
  89. )
  90. user_input = (await _async_prompt(exit_event=exit_event)).strip()
  91. if exit_event.is_set():
  92. return
  93. # Fallback to environment variable when user left blank
  94. final_value = user_input
  95. if not final_value:
  96. final_value = os.getenv(env_variable_key, "")
  97. if final_value:
  98. print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
  99. else:
  100. print(
  101. f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
  102. )
  103. resolved_inputs[input_id] = final_value
  104. # Inject resolved value (can be empty) into stdio's env or http/sse's headers
  105. for server in servers:
  106. env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
  107. for key, value in env_or_headers.items():
  108. if env_special_value in value:
  109. env_or_headers[key] = env_or_headers[key].replace(env_special_value, final_value)
  110. print()
  111. raw_api_key = config.get("apiKey")
  112. if isinstance(raw_api_key, str):
  113. substituted_api_key = raw_api_key
  114. for input_id, val in resolved_inputs.items():
  115. substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
  116. config["apiKey"] = substituted_api_key
  117. # Main agent loop
  118. async with Agent(
  119. provider=config.get("provider"), # type: ignore[arg-type]
  120. model=config.get("model"),
  121. base_url=config.get("endpointUrl"), # type: ignore[arg-type]
  122. api_key=config.get("apiKey"),
  123. servers=servers, # type: ignore[arg-type]
  124. prompt=prompt,
  125. ) as agent:
  126. await agent.load_tools()
  127. print(f"[bold blue]Agent loaded with {len(agent.available_tools)} tools:[/bold blue]")
  128. for t in agent.available_tools:
  129. print(f"[blue] • {t.function.name}[/blue]")
  130. while True:
  131. abort_event.clear()
  132. # Check if we should exit
  133. if exit_event.is_set():
  134. return
  135. try:
  136. user_input = await _async_prompt(exit_event=exit_event)
  137. first_sigint = True
  138. except EOFError:
  139. print("\n[red]EOF received, exiting.[/red]", flush=True)
  140. break
  141. except KeyboardInterrupt:
  142. if not first_sigint and abort_event.is_set():
  143. continue
  144. else:
  145. print("\n[red]Keyboard interrupt during input processing.[/red]", flush=True)
  146. break
  147. try:
  148. async for chunk in agent.run(user_input, abort_event=abort_event):
  149. if abort_event.is_set() and not first_sigint:
  150. break
  151. if exit_event.is_set():
  152. return
  153. if hasattr(chunk, "choices"):
  154. delta = chunk.choices[0].delta
  155. if delta.content:
  156. print(delta.content, end="", flush=True)
  157. if delta.tool_calls:
  158. for call in delta.tool_calls:
  159. if call.id:
  160. print(f"<Tool {call.id}>", end="")
  161. if call.function.name:
  162. print(f"{call.function.name}", end=" ")
  163. if call.function.arguments:
  164. print(f"{call.function.arguments}", end="")
  165. else:
  166. print(
  167. f"\n\n[green]Tool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}[/green]\n",
  168. flush=True,
  169. )
  170. print()
  171. except Exception as e:
  172. tb_str = traceback.format_exc()
  173. print(f"\n[bold red]Error during agent run: {e}\n{tb_str}[/bold red]", flush=True)
  174. first_sigint = True # Allow graceful interrupt for the next command
  175. except Exception as e:
  176. tb_str = traceback.format_exc()
  177. print(f"\n[bold red]An unexpected error occurred: {e}\n{tb_str}[/bold red]", flush=True)
  178. raise e
  179. finally:
  180. if sigint_registered_in_loop:
  181. try:
  182. loop.remove_signal_handler(signal.SIGINT)
  183. except (AttributeError, NotImplementedError):
  184. pass
  185. else:
  186. signal.signal(signal.SIGINT, original_sigint_handler)
  187. @run_cli.callback()
  188. def run(
  189. path: Optional[str] = typer.Argument(
  190. None,
  191. help=(
  192. "Path to a local folder containing an agent.json file or a built-in agent "
  193. "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
  194. "(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
  195. ),
  196. show_default=False,
  197. ),
  198. ):
  199. try:
  200. asyncio.run(run_agent(path))
  201. except KeyboardInterrupt:
  202. print("\n[red]Application terminated by KeyboardInterrupt.[/red]", flush=True)
  203. raise typer.Exit(code=130)
  204. except Exception as e:
  205. print(f"\n[bold red]An unexpected error occurred: {e}[/bold red]", flush=True)
  206. raise e
  207. if __name__ == "__main__":
  208. app()