strawberry.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. import functools
  2. import hashlib
  3. import warnings
  4. from inspect import isawaitable
  5. import sentry_sdk
  6. from sentry_sdk.consts import OP
  7. from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable
  8. from sentry_sdk.integrations.logging import ignore_logger
  9. from sentry_sdk.scope import should_send_default_pii
  10. from sentry_sdk.tracing import TransactionSource
  11. from sentry_sdk.utils import (
  12. capture_internal_exceptions,
  13. ensure_integration_enabled,
  14. event_from_exception,
  15. logger,
  16. package_version,
  17. _get_installed_modules,
  18. )
  19. try:
  20. from functools import cached_property
  21. except ImportError:
  22. # The strawberry integration requires Python 3.8+. functools.cached_property
  23. # was added in 3.8, so this check is technically not needed, but since this
  24. # is an auto-enabling integration, we might get to executing this import in
  25. # lower Python versions, so we need to deal with it.
  26. raise DidNotEnable("strawberry-graphql integration requires Python 3.8 or newer")
  27. try:
  28. from strawberry import Schema
  29. from strawberry.extensions import SchemaExtension
  30. from strawberry.extensions.tracing.utils import (
  31. should_skip_tracing as strawberry_should_skip_tracing,
  32. )
  33. from strawberry.http import async_base_view, sync_base_view
  34. except ImportError:
  35. raise DidNotEnable("strawberry-graphql is not installed")
  36. try:
  37. from strawberry.extensions.tracing import (
  38. SentryTracingExtension as StrawberrySentryAsyncExtension,
  39. SentryTracingExtensionSync as StrawberrySentrySyncExtension,
  40. )
  41. except ImportError:
  42. StrawberrySentryAsyncExtension = None
  43. StrawberrySentrySyncExtension = None
  44. from typing import TYPE_CHECKING
  45. if TYPE_CHECKING:
  46. from typing import Any, Callable, Generator, List, Optional
  47. from graphql import GraphQLError, GraphQLResolveInfo
  48. from strawberry.http import GraphQLHTTPResponse
  49. from strawberry.types import ExecutionContext
  50. from sentry_sdk._types import Event, EventProcessor
# Tell Sentry's logging integration to ignore strawberry's "strawberry.execution"
# logger. GraphQL errors are instead captured explicitly through the patched
# views' _handle_errors hook (see _patch_views), so this presumably avoids
# reporting the same error twice. NOTE(review): confirm what strawberry logs
# on this logger.
ignore_logger("strawberry.execution")
  52. class StrawberryIntegration(Integration):
  53. identifier = "strawberry"
  54. origin = f"auto.graphql.{identifier}"
  55. def __init__(self, async_execution: "Optional[bool]" = None) -> None:
  56. if async_execution not in (None, False, True):
  57. raise ValueError(
  58. 'Invalid value for async_execution: "{}" (must be bool)'.format(
  59. async_execution
  60. )
  61. )
  62. self.async_execution = async_execution
  63. @staticmethod
  64. def setup_once() -> None:
  65. version = package_version("strawberry-graphql")
  66. _check_minimum_version(StrawberryIntegration, version, "strawberry-graphql")
  67. _patch_schema_init()
  68. _patch_views()
  69. def _patch_schema_init() -> None:
  70. old_schema_init = Schema.__init__
  71. @functools.wraps(old_schema_init)
  72. def _sentry_patched_schema_init(
  73. self: "Schema", *args: "Any", **kwargs: "Any"
  74. ) -> None:
  75. integration = sentry_sdk.get_client().get_integration(StrawberryIntegration)
  76. if integration is None:
  77. return old_schema_init(self, *args, **kwargs)
  78. extensions = kwargs.get("extensions") or []
  79. should_use_async_extension: "Optional[bool]" = None
  80. if integration.async_execution is not None:
  81. should_use_async_extension = integration.async_execution
  82. else:
  83. # try to figure it out ourselves
  84. should_use_async_extension = _guess_if_using_async(extensions)
  85. if should_use_async_extension is None:
  86. warnings.warn(
  87. "Assuming strawberry is running sync. If not, initialize the integration as StrawberryIntegration(async_execution=True).",
  88. stacklevel=2,
  89. )
  90. should_use_async_extension = False
  91. # remove the built in strawberry sentry extension, if present
  92. extensions = [
  93. extension
  94. for extension in extensions
  95. if extension
  96. not in (StrawberrySentryAsyncExtension, StrawberrySentrySyncExtension)
  97. ]
  98. # add our extension
  99. extensions.append(
  100. SentryAsyncExtension if should_use_async_extension else SentrySyncExtension
  101. )
  102. kwargs["extensions"] = extensions
  103. return old_schema_init(self, *args, **kwargs)
  104. Schema.__init__ = _sentry_patched_schema_init # type: ignore[method-assign]
  105. class SentryAsyncExtension(SchemaExtension):
  106. def __init__(
  107. self: "Any",
  108. *,
  109. execution_context: "Optional[ExecutionContext]" = None,
  110. ) -> None:
  111. if execution_context:
  112. self.execution_context = execution_context
  113. @cached_property
  114. def _resource_name(self) -> str:
  115. query_hash = self.hash_query(self.execution_context.query) # type: ignore
  116. if self.execution_context.operation_name:
  117. return "{}:{}".format(self.execution_context.operation_name, query_hash)
  118. return query_hash
  119. def hash_query(self, query: str) -> str:
  120. return hashlib.md5(query.encode("utf-8")).hexdigest()
  121. def on_operation(self) -> "Generator[None, None, None]":
  122. self._operation_name = self.execution_context.operation_name
  123. operation_type = "query"
  124. op = OP.GRAPHQL_QUERY
  125. if self.execution_context.query is None:
  126. self.execution_context.query = ""
  127. if self.execution_context.query.strip().startswith("mutation"):
  128. operation_type = "mutation"
  129. op = OP.GRAPHQL_MUTATION
  130. elif self.execution_context.query.strip().startswith("subscription"):
  131. operation_type = "subscription"
  132. op = OP.GRAPHQL_SUBSCRIPTION
  133. description = operation_type
  134. if self._operation_name:
  135. description += " {}".format(self._operation_name)
  136. sentry_sdk.add_breadcrumb(
  137. category="graphql.operation",
  138. data={
  139. "operation_name": self._operation_name,
  140. "operation_type": operation_type,
  141. },
  142. )
  143. scope = sentry_sdk.get_isolation_scope()
  144. event_processor = _make_request_event_processor(self.execution_context)
  145. scope.add_event_processor(event_processor)
  146. span = sentry_sdk.get_current_span()
  147. if span:
  148. self.graphql_span = span.start_child(
  149. op=op,
  150. name=description,
  151. origin=StrawberryIntegration.origin,
  152. )
  153. else:
  154. self.graphql_span = sentry_sdk.start_span(
  155. op=op,
  156. name=description,
  157. origin=StrawberryIntegration.origin,
  158. )
  159. self.graphql_span.set_data("graphql.operation.type", operation_type)
  160. self.graphql_span.set_data("graphql.operation.name", self._operation_name)
  161. self.graphql_span.set_data("graphql.document", self.execution_context.query)
  162. self.graphql_span.set_data("graphql.resource_name", self._resource_name)
  163. yield
  164. transaction = self.graphql_span.containing_transaction
  165. if transaction and self.execution_context.operation_name:
  166. transaction.name = self.execution_context.operation_name
  167. transaction.source = TransactionSource.COMPONENT
  168. transaction.op = op
  169. self.graphql_span.finish()
  170. def on_validate(self) -> "Generator[None, None, None]":
  171. self.validation_span = self.graphql_span.start_child(
  172. op=OP.GRAPHQL_VALIDATE,
  173. name="validation",
  174. origin=StrawberryIntegration.origin,
  175. )
  176. yield
  177. self.validation_span.finish()
  178. def on_parse(self) -> "Generator[None, None, None]":
  179. self.parsing_span = self.graphql_span.start_child(
  180. op=OP.GRAPHQL_PARSE,
  181. name="parsing",
  182. origin=StrawberryIntegration.origin,
  183. )
  184. yield
  185. self.parsing_span.finish()
  186. def should_skip_tracing(
  187. self,
  188. _next: "Callable[[Any, GraphQLResolveInfo, Any, Any], Any]",
  189. info: "GraphQLResolveInfo",
  190. ) -> bool:
  191. return strawberry_should_skip_tracing(_next, info)
  192. async def _resolve(
  193. self,
  194. _next: "Callable[[Any, GraphQLResolveInfo, Any, Any], Any]",
  195. root: "Any",
  196. info: "GraphQLResolveInfo",
  197. *args: str,
  198. **kwargs: "Any",
  199. ) -> "Any":
  200. result = _next(root, info, *args, **kwargs)
  201. if isawaitable(result):
  202. result = await result
  203. return result
  204. async def resolve(
  205. self,
  206. _next: "Callable[[Any, GraphQLResolveInfo, Any, Any], Any]",
  207. root: "Any",
  208. info: "GraphQLResolveInfo",
  209. *args: str,
  210. **kwargs: "Any",
  211. ) -> "Any":
  212. if self.should_skip_tracing(_next, info):
  213. return await self._resolve(_next, root, info, *args, **kwargs)
  214. field_path = "{}.{}".format(info.parent_type, info.field_name)
  215. with self.graphql_span.start_child(
  216. op=OP.GRAPHQL_RESOLVE,
  217. name="resolving {}".format(field_path),
  218. origin=StrawberryIntegration.origin,
  219. ) as span:
  220. span.set_data("graphql.field_name", info.field_name)
  221. span.set_data("graphql.parent_type", info.parent_type.name)
  222. span.set_data("graphql.field_path", field_path)
  223. span.set_data("graphql.path", ".".join(map(str, info.path.as_list())))
  224. return await self._resolve(_next, root, info, *args, **kwargs)
  225. class SentrySyncExtension(SentryAsyncExtension):
  226. def resolve(
  227. self,
  228. _next: "Callable[[Any, Any, Any, Any], Any]",
  229. root: "Any",
  230. info: "GraphQLResolveInfo",
  231. *args: str,
  232. **kwargs: "Any",
  233. ) -> "Any":
  234. if self.should_skip_tracing(_next, info):
  235. return _next(root, info, *args, **kwargs)
  236. field_path = "{}.{}".format(info.parent_type, info.field_name)
  237. with self.graphql_span.start_child(
  238. op=OP.GRAPHQL_RESOLVE,
  239. name="resolving {}".format(field_path),
  240. origin=StrawberryIntegration.origin,
  241. ) as span:
  242. span.set_data("graphql.field_name", info.field_name)
  243. span.set_data("graphql.parent_type", info.parent_type.name)
  244. span.set_data("graphql.field_path", field_path)
  245. span.set_data("graphql.path", ".".join(map(str, info.path.as_list())))
  246. return _next(root, info, *args, **kwargs)
  247. def _patch_views() -> None:
  248. old_async_view_handle_errors = async_base_view.AsyncBaseHTTPView._handle_errors
  249. old_sync_view_handle_errors = sync_base_view.SyncBaseHTTPView._handle_errors
  250. def _sentry_patched_async_view_handle_errors(
  251. self: "Any", errors: "List[GraphQLError]", response_data: "GraphQLHTTPResponse"
  252. ) -> None:
  253. old_async_view_handle_errors(self, errors, response_data)
  254. _sentry_patched_handle_errors(self, errors, response_data)
  255. def _sentry_patched_sync_view_handle_errors(
  256. self: "Any", errors: "List[GraphQLError]", response_data: "GraphQLHTTPResponse"
  257. ) -> None:
  258. old_sync_view_handle_errors(self, errors, response_data)
  259. _sentry_patched_handle_errors(self, errors, response_data)
  260. @ensure_integration_enabled(StrawberryIntegration)
  261. def _sentry_patched_handle_errors(
  262. self: "Any", errors: "List[GraphQLError]", response_data: "GraphQLHTTPResponse"
  263. ) -> None:
  264. if not errors:
  265. return
  266. scope = sentry_sdk.get_isolation_scope()
  267. event_processor = _make_response_event_processor(response_data)
  268. scope.add_event_processor(event_processor)
  269. with capture_internal_exceptions():
  270. for error in errors:
  271. event, hint = event_from_exception(
  272. error,
  273. client_options=sentry_sdk.get_client().options,
  274. mechanism={
  275. "type": StrawberryIntegration.identifier,
  276. "handled": False,
  277. },
  278. )
  279. sentry_sdk.capture_event(event, hint=hint)
  280. async_base_view.AsyncBaseHTTPView._handle_errors = ( # type: ignore[method-assign]
  281. _sentry_patched_async_view_handle_errors
  282. )
  283. sync_base_view.SyncBaseHTTPView._handle_errors = ( # type: ignore[method-assign]
  284. _sentry_patched_sync_view_handle_errors
  285. )
  286. def _make_request_event_processor(
  287. execution_context: "ExecutionContext",
  288. ) -> "EventProcessor":
  289. def inner(event: "Event", hint: "dict[str, Any]") -> "Event":
  290. with capture_internal_exceptions():
  291. if should_send_default_pii():
  292. request_data = event.setdefault("request", {})
  293. request_data["api_target"] = "graphql"
  294. if not request_data.get("data"):
  295. data: "dict[str, Any]" = {"query": execution_context.query}
  296. if execution_context.variables:
  297. data["variables"] = execution_context.variables
  298. if execution_context.operation_name:
  299. data["operationName"] = execution_context.operation_name
  300. request_data["data"] = data
  301. else:
  302. try:
  303. del event["request"]["data"]
  304. except (KeyError, TypeError):
  305. pass
  306. return event
  307. return inner
  308. def _make_response_event_processor(
  309. response_data: "GraphQLHTTPResponse",
  310. ) -> "EventProcessor":
  311. def inner(event: "Event", hint: "dict[str, Any]") -> "Event":
  312. with capture_internal_exceptions():
  313. if should_send_default_pii():
  314. contexts = event.setdefault("contexts", {})
  315. contexts["response"] = {"data": response_data}
  316. return event
  317. return inner
  318. def _guess_if_using_async(extensions: "List[SchemaExtension]") -> "Optional[bool]":
  319. if StrawberrySentryAsyncExtension in extensions:
  320. return True
  321. elif StrawberrySentrySyncExtension in extensions:
  322. return False
  323. return None