api.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import abc
  9. import time
  10. from collections import namedtuple
  11. from functools import wraps
  12. from typing import Optional
  13. from typing_extensions import deprecated
# Public names of this module; this list is what ``from <module> import *`` exports.
__all__ = [
    "MetricsConfig",
    "MetricHandler",
    "ConsoleMetricHandler",
    "NullMetricHandler",
    "MetricStream",
    "configure",
    "getStream",
    "prof",
    "profile",
    "put_metric",
    "publish_metric",
    "get_elapsed_time_ms",
    "MetricData",
]
  29. MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
  30. class MetricsConfig:
  31. __slots__ = ["params"]
  32. def __init__(self, params: Optional[dict[str, str]] = None):
  33. self.params = params
  34. if self.params is None:
  35. self.params = {}
  36. class MetricHandler(abc.ABC):
  37. @abc.abstractmethod
  38. def emit(self, metric_data: MetricData):
  39. pass
  40. class ConsoleMetricHandler(MetricHandler):
  41. def emit(self, metric_data: MetricData):
  42. print(
  43. f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
  44. )
  45. class NullMetricHandler(MetricHandler):
  46. def emit(self, metric_data: MetricData):
  47. pass
  48. class MetricStream:
  49. def __init__(self, group_name: str, handler: MetricHandler):
  50. self.group_name = group_name
  51. self.handler = handler
  52. def add_value(self, metric_name: str, metric_value: int):
  53. self.handler.emit(
  54. MetricData(time.time(), self.group_name, metric_name, metric_value)
  55. )
# Handlers registered for specific metric groups, keyed by group name;
# populated by ``configure(handler, group)``.
_metrics_map: dict[str, MetricHandler] = {}
# Handler ``getStream`` falls back to for groups missing from ``_metrics_map``.
# Starts as a ``NullMetricHandler`` until ``configure(handler)`` replaces it.
_default_metrics_handler: MetricHandler = NullMetricHandler()
  58. # pyre-fixme[9]: group has type `str`; used as `None`.
  59. def configure(handler: MetricHandler, group: Optional[str] = None):
  60. if group is None:
  61. global _default_metrics_handler
  62. # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
  63. # as `MetricHandler`.
  64. _default_metrics_handler = handler
  65. else:
  66. _metrics_map[group] = handler
  67. def getStream(group: str):
  68. if group in _metrics_map:
  69. handler = _metrics_map[group]
  70. else:
  71. handler = _default_metrics_handler
  72. return MetricStream(group, handler)
  73. def _get_metric_name(fn):
  74. qualname = fn.__qualname__
  75. split = qualname.split(".")
  76. if len(split) == 1:
  77. module = fn.__module__
  78. if module:
  79. return module.split(".")[-1] + "." + split[0]
  80. else:
  81. return split[0]
  82. else:
  83. return qualname
  84. def prof(fn=None, group: str = "torchelastic"):
  85. r"""
  86. @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates.
  87. The metric name defaults to the qualified name (``class_name.def_name``) of the function.
  88. If the function does not belong to a class, it uses the leaf module name instead.
  89. Usage
  90. ::
  91. @metrics.prof
  92. def x():
  93. pass
  94. @metrics.prof(group="agent")
  95. def y():
  96. pass
  97. """
  98. def wrap(f):
  99. @wraps(f)
  100. def wrapper(*args, **kwargs):
  101. key = _get_metric_name(f)
  102. try:
  103. start = time.time()
  104. result = f(*args, **kwargs)
  105. put_metric(f"{key}.success", 1, group)
  106. except Exception:
  107. put_metric(f"{key}.failure", 1, group)
  108. raise
  109. finally:
  110. put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined]
  111. return result
  112. return wrapper
  113. if fn:
  114. return wrap(fn)
  115. else:
  116. return wrap
  117. @deprecated("Deprecated, use `@prof` instead", category=FutureWarning)
  118. def profile(group=None):
  119. """
  120. @profile decorator adds latency and success/failure metrics to any given function.
  121. Usage
  122. ::
  123. @metrics.profile("my_metric_group")
  124. def some_function(<arguments>):
  125. """
  126. def wrap(func):
  127. @wraps(func)
  128. def wrapper(*args, **kwargs):
  129. try:
  130. start_time = time.time()
  131. result = func(*args, **kwargs)
  132. publish_metric(group, f"{func.__name__}.success", 1)
  133. except Exception:
  134. publish_metric(group, f"{func.__name__}.failure", 1)
  135. raise
  136. finally:
  137. publish_metric(
  138. group,
  139. f"{func.__name__}.duration.ms",
  140. get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
  141. )
  142. return result
  143. return wrapper
  144. return wrap
  145. def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
  146. """
  147. Publish a metric data point.
  148. Usage
  149. ::
  150. put_metric("metric_name", 1)
  151. put_metric("metric_name", 1, "metric_group_name")
  152. """
  153. getStream(metric_group).add_value(metric_name, metric_value)
  154. @deprecated(
  155. "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
  156. category=FutureWarning,
  157. )
  158. def publish_metric(metric_group: str, metric_name: str, metric_value: int):
  159. metric_stream = getStream(metric_group)
  160. metric_stream.add_value(metric_name, metric_value)
  161. def get_elapsed_time_ms(start_time_in_seconds: float):
  162. """Return the elapsed time in millis from the given start time."""
  163. end_time = time.time()
  164. return int((end_time - start_time_in_seconds) * 1000)