telemetry.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import re
  2. import sys
  3. from types import TracebackType
  4. from typing import TYPE_CHECKING, ContextManager, Dict, List, Optional, Set, Type
  5. import wandb
  6. from wandb.proto.wandb_telemetry_pb2 import Imports as TelemetryImports
  7. from wandb.proto.wandb_telemetry_pb2 import TelemetryRecord
  8. # avoid cycle, use string type reference
  9. if TYPE_CHECKING:
  10. from .. import wandb_run
  # Marker that opens an inline label in source text (e.g. "@wandbcode{...}");
  # _parse_label_lines searches each line for this exact substring.
  _LABEL_TOKEN: str = "@wandbcode{"
  class _TelemetryObject:
      """Context manager that hands out a TelemetryRecord and reports it on exit.

      The caller mutates the yielded record inside the ``with`` block. When the
      block exits, whether normally or because of an exception, the record is
      passed to the run's ``_telemetry_callback``, provided a run is available.
      """

      _run: Optional["wandb_run.Run"]
      _obj: TelemetryRecord

      def __init__(
          self,
          run: Optional["wandb_run.Run"] = None,
          obj: Optional[TelemetryRecord] = None,
      ) -> None:
          # A missing (or falsy) run falls back to ``wandb.run``; a missing (or
          # falsy) record is replaced by a freshly constructed one.
          self._run = run if run else wandb.run
          self._obj = obj if obj else TelemetryRecord()

      def __enter__(self) -> TelemetryRecord:
          return self._obj

      def __exit__(
          self,
          exctype: Optional[Type[BaseException]],
          excinst: Optional[BaseException],
          exctb: Optional[TracebackType],
      ) -> None:
          # The exception arguments are ignored, and returning None means any
          # in-flight exception keeps propagating after the callback runs.
          target_run = self._run
          if target_run:
              target_run._telemetry_callback(self._obj)
  def context(
      run: Optional["wandb_run.Run"] = None,
      obj: Optional[TelemetryRecord] = None,
  ) -> ContextManager[TelemetryRecord]:
      """Return a context manager that yields a TelemetryRecord to populate.

      Args:
          run: run whose telemetry callback receives the record on exit;
              ``wandb.run`` is used when this is omitted.
          obj: record to populate; a new one is created when this is omitted.

      Returns:
          A context manager whose ``__enter__`` returns the record.
      """
      manager = _TelemetryObject(run=run, obj=obj)
      return manager
  # Optional leading identifier of a label body: a name followed by "," or "}".
  MATCH_RE = re.compile(r"(?P<code>[a-zA-Z0-9_-]+)[,}](?P<rest>.*)")


  def _parse_label_lines(lines: List[str]) -> Dict[str, str]:
      """Extract label data from a contiguous run of lines carrying the label token.

      Lines before the first one that contains ``_LABEL_TOKEN`` are skipped.
      Once a labelled line has been seen, parsing stops at the first line that
      lacks the token. For each labelled line, the text after the token may
      start with a bare identifier, which is stored under ``"code"``. After that,
      ``key=value`` pairs terminated by ``,`` or ``}`` are stored under their
      key. The code and all values have ``-`` replaced by ``_``, and surrounding
      ``"`` characters are stripped from values. A key that appears again on a
      later line overwrites the earlier value.

      The parser is deliberately permissive: fragments that do not match the
      expected syntax are silently ignored instead of raising.
      """
      parsed: Dict[str, str] = {}
      in_label_run = False
      for text in lines:
          _, marker, remainder = text.partition(_LABEL_TOKEN)
          if not marker:
              if in_label_run:
                  # The contiguous block of labelled lines has ended.
                  break
              continue
          in_label_run = True
          code_match = MATCH_RE.match(remainder)
          if code_match is not None:
              parsed["code"] = code_match.group("code").replace("-", "_")
              remainder = code_match.group("rest")
          pairs = re.findall(
              r'([a-zA-Z0-9_]+)\s*=\s*("[a-zA-Z0-9_-]*"|[a-zA-Z0-9_-]*)[,}]', remainder
          )
          parsed.update(
              (key, raw.strip('"').replace("-", "_")) for key, raw in pairs
          )
      return parsed
  def list_telemetry_imports(only_imported: bool = False) -> Set[str]:
      """Return the names of the boolean fields of the telemetry ``Imports`` message.

      Args:
          only_imported: when true, keep only the names that are also keys of
              ``sys.modules``.

      Returns:
          A new set of field-name strings.
      """
      bool_field_names: Set[str] = set()
      for field in TelemetryImports.DESCRIPTOR.fields:
          if field.type == field.TYPE_BOOL:
              bool_field_names.add(field.name)
      if not only_imported:
          return bool_field_names
      loaded = set(sys.modules)
      return {name for name in bool_field_names if name in loaded}
  # Names re-exported as this module's public API.
  __all__ = [
      "TelemetryImports", "TelemetryRecord",
      "context", "list_telemetry_imports",
  ]