_utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. from __future__ import annotations
  2. from typing import Any, Collection, Final, Optional, Protocol, TypedDict
  3. from pydantic import Field
  4. from typing_extensions import Annotated, Self, Unpack
  5. from wandb._pydantic import GQLId, GQLInput, computed_field, model_validator, to_json
  6. from ._filters import MongoLikeFilter
  7. from ._generated import (
  8. CreateFilterTriggerInput,
  9. QueueJobActionInput,
  10. TriggeredActionConfig,
  11. UpdateFilterTriggerInput,
  12. )
  13. from ._validators import parse_input_action
  14. from .actions import (
  15. ActionType,
  16. DoNothing,
  17. InputAction,
  18. SavedAction,
  19. SendNotification,
  20. SendWebhook,
  21. )
  22. from .automations import Automation, NewAutomation
  23. from .events import EventType, InputEvent, RunMetricFilter, _WrappedSavedEventFilter
  24. from .scopes import AutomationScope, ScopeType
  25. INVALID_INPUT_EVENTS: Final[Collection[EventType]] = (EventType.UPDATE_ARTIFACT_ALIAS,)
  26. """Event types that should NOT be allowed as new values on new or edited automations.
  27. While we forbid new/edited automations from assigning these event types,
  28. they're defined so that we can still parse existing automations that may use them.
  29. """
  30. INVALID_INPUT_ACTIONS: Final[Collection[ActionType]] = (ActionType.QUEUE_JOB,)
  31. """Action types that should NOT be allowed as new values on new or edited automations.
  32. While we forbid new/edited automations from assigning these action types,
  33. they're defined so that we can still parse existing automations that may use them.
  34. """
  35. ALWAYS_SUPPORTED_EVENTS: Final[Collection[EventType]] = frozenset(
  36. {
  37. EventType.CREATE_ARTIFACT,
  38. EventType.LINK_ARTIFACT,
  39. EventType.ADD_ARTIFACT_ALIAS,
  40. }
  41. )
  42. """Event types that should be supported by all current, non-EOL server versions."""
  43. ALWAYS_SUPPORTED_ACTIONS: Final[Collection[ActionType]] = frozenset(
  44. {
  45. ActionType.NOTIFICATION,
  46. ActionType.GENERIC_WEBHOOK,
  47. }
  48. )
  49. """Action types that should be supported by all current, non-EOL server versions."""
  50. class HasId(Protocol):
  51. id: str
  52. def extract_id(obj: HasId | str) -> str:
  53. return obj.id if hasattr(obj, "id") else obj
  54. # ---------------------------------------------------------------------------
  55. ACTION_CONFIG_KEYS: dict[ActionType, str] = {
  56. ActionType.NOTIFICATION: "notification_action_input",
  57. ActionType.GENERIC_WEBHOOK: "generic_webhook_action_input",
  58. ActionType.NO_OP: "no_op_action_input",
  59. ActionType.QUEUE_JOB: "queue_job_action_input",
  60. }
  61. class InputActionConfig(TriggeredActionConfig):
  62. """Prepares action configuration data for saving an automation."""
  63. # NOTE: `QueueJobActionInput` for defining a Launch job is deprecated,
  64. # so while it's allowed here to update EXISTING mutations, we don't
  65. # currently expose it through the public API to create NEW automations.
  66. queue_job_action_input: Optional[QueueJobActionInput] = None
  67. notification_action_input: Optional[SendNotification] = None
  68. generic_webhook_action_input: Optional[SendWebhook] = None
  69. no_op_action_input: Optional[DoNothing] = None
  70. def prepare_action_config_input(obj: SavedAction | InputAction) -> dict[str, Any]:
  71. """Nests the action input under the correct key for `TriggeredActionConfig`.
  72. This is necessary to conform to the schemas for:
  73. - `CreateFilterTriggerInput`
  74. - `UpdateFilterTriggerInput`
  75. """
  76. # Delegate to inner validators to convert SavedAction -> InputAction types, if needed.
  77. obj = parse_input_action(obj)
  78. return InputActionConfig(**{ACTION_CONFIG_KEYS[obj.action_type]: obj}).model_dump()
  79. def prepare_event_filter_input(
  80. obj: _WrappedSavedEventFilter | MongoLikeFilter | RunMetricFilter,
  81. ) -> str:
  82. """Unnests (if needed) and serializes an `EventFilter` input to JSON.
  83. This is necessary to conform to the schemas for:
  84. - `CreateFilterTriggerInput`
  85. - `UpdateFilterTriggerInput`
  86. """
  87. # Input event filters are nested one level deeper than saved event filters.
  88. # Note that this is NOT the case for run/run metric filters.
  89. #
  90. # Yes, this is confusing. It's also necessary to conform to under-the-hood
  91. # schemas and logic in the backend.
  92. if isinstance(obj, _WrappedSavedEventFilter):
  93. return to_json(obj.filter)
  94. return to_json(obj)
  95. class WriteAutomationsKwargs(TypedDict, total=False):
  96. """Keyword arguments that can be passed to create or update an automation."""
  97. name: str
  98. description: str
  99. enabled: bool
  100. scope: AutomationScope
  101. event: InputEvent
  102. action: InputAction
  103. class ValidatedCreateInput(GQLInput, extra="forbid", frozen=True):
  104. """Validated automation parameters, prepared for creating a new automation.
  105. Note: Users should never need to instantiate this class directly.
  106. """
  107. name: str
  108. description: Optional[str] = None
  109. enabled: bool = True
  110. # ------------------------------------------------------------------------------
  111. # Set on instantiation, but used to derive other fields and deliberately
  112. # EXCLUDED from the final GraphQL request vars
  113. event: Annotated[InputEvent, Field(exclude=True)]
  114. action: Annotated[InputAction, Field(exclude=True)]
  115. # ------------------------------------------------------------------------------
  116. # Derived fields to match the input schemas
  117. @computed_field
  118. def scope_type(self) -> ScopeType:
  119. return self.event.scope.scope_type
  120. @computed_field
  121. def scope_id(self) -> GQLId:
  122. return self.event.scope.id
  123. @computed_field
  124. def triggering_event_type(self) -> EventType:
  125. return self.event.event_type
  126. @computed_field
  127. def event_filter(self) -> str:
  128. return prepare_event_filter_input(self.event.filter)
  129. @computed_field
  130. def triggered_action_type(self) -> ActionType:
  131. return self.action.action_type
  132. @computed_field
  133. def triggered_action_config(self) -> dict[str, Any]:
  134. return prepare_action_config_input(self.action)
  135. # ------------------------------------------------------------------------------
  136. # Custom validation
  137. @model_validator(mode="after")
  138. def _forbid_legacy_event_types(self) -> Self:
  139. if (type_ := self.event.event_type) in INVALID_INPUT_EVENTS:
  140. raise ValueError(f"{type_!r} events cannot be assigned to automations.")
  141. return self
  142. @model_validator(mode="after")
  143. def _forbid_legacy_action_types(self) -> Self:
  144. if (type_ := self.action.action_type) in INVALID_INPUT_ACTIONS:
  145. raise ValueError(f"{type_!r} actions cannot be assigned to automations.")
  146. return self
  147. def prepare_to_create(
  148. obj: NewAutomation | None = None,
  149. /,
  150. **kwargs: Unpack[WriteAutomationsKwargs],
  151. ) -> CreateFilterTriggerInput:
  152. """Prepares the payload to create an automation in a GraphQL request."""
  153. # Validate all input variables, and prepare as expected by the GraphQL request.
  154. # - if an object is provided, override its fields with any keyword args
  155. # - otherwise, instantiate from the keyword args
  156. obj_dict = {**obj.model_dump(), **kwargs} if obj else kwargs
  157. vobj = ValidatedCreateInput(**obj_dict)
  158. return CreateFilterTriggerInput.model_validate(vobj)
  159. def prepare_to_update(
  160. obj: Automation | None = None,
  161. /,
  162. **kwargs: Unpack[WriteAutomationsKwargs],
  163. ) -> UpdateFilterTriggerInput:
  164. """Prepares the payload to update an automation in a GraphQL request."""
  165. # Validate all values:
  166. # - if an object is provided, override its fields with any keyword args
  167. # - otherwise, instantiate from the keyword args
  168. vobj = Automation(**{**dict(obj or {}), **kwargs})
  169. return UpdateFilterTriggerInput(
  170. id=vobj.id,
  171. name=vobj.name,
  172. description=vobj.description,
  173. enabled=vobj.enabled,
  174. scope_type=vobj.scope.scope_type,
  175. scope_id=vobj.scope.id,
  176. triggering_event_type=vobj.event.event_type,
  177. event_filter=prepare_event_filter_input(vobj.event.filter),
  178. triggered_action_type=vobj.action.action_type,
  179. triggered_action_config=prepare_action_config_input(vobj.action),
  180. )