interface.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086
  1. from __future__ import annotations
  2. import abc
  3. import gzip
  4. import logging
  5. import time
  6. from pathlib import Path
  7. from secrets import token_hex
  8. from typing import TYPE_CHECKING, Any, Iterable
  9. from wandb import termwarn
  10. from wandb.proto import wandb_internal_pb2 as pb
  11. from wandb.proto import wandb_telemetry_pb2 as tpb
  12. from wandb.sdk.lib import json_util as json
  13. from wandb.sdk.lib.filesystem import FilesDict, PolicyName
  14. from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle
  15. from wandb.util import (
  16. WandBJSONEncoderOld,
  17. get_h5_typename,
  18. json_dumps_safer,
  19. json_dumps_safer_history,
  20. json_friendly,
  21. json_friendly_val,
  22. maybe_compress_summary,
  23. )
  24. from ..data_types.utils import history_dict_to_json, val_to_json
  25. from . import summary_record as sr
  26. MANIFEST_FILE_SIZE_THRESHOLD = 100_000
  27. if TYPE_CHECKING:
  28. from wandb.sdk.artifacts.artifact import Artifact
  29. from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
  30. from ..wandb_run import Run
  31. logger = logging.getLogger("wandb")
  def file_policy_to_enum(policy: PolicyName) -> pb.FilesItem.PolicyType.V:
      """Translate a file policy name into its protobuf enum value.

      Args:
          policy: One of "now", "end" or "live".

      Returns:
          The matching `pb.FilesItem.PolicyType` member.

      Raises:
          ValueError: If `policy` is not one of the recognized names.
      """
      if policy == "now":
          return pb.FilesItem.PolicyType.NOW
      if policy == "end":
          return pb.FilesItem.PolicyType.END
      if policy == "live":
          return pb.FilesItem.PolicyType.LIVE
      # An unrecognized name used to fall through every branch and reach the
      # `return` with the local never assigned (UnboundLocalError); fail
      # with a message that names the offending value instead.
      raise ValueError(f"Unknown file policy: {policy!r}")
  def file_enum_to_policy(enum: pb.FilesItem.PolicyType.V) -> PolicyName:
      """Translate a protobuf file policy enum value into its policy name.

      Args:
          enum: A `pb.FilesItem.PolicyType` member (NOW, END or LIVE).

      Returns:
          The matching policy name: "now", "end" or "live".

      Raises:
          ValueError: If `enum` is not one of the handled members.
      """
      if enum == pb.FilesItem.PolicyType.NOW:
          return "now"
      if enum == pb.FilesItem.PolicyType.END:
          return "end"
      if enum == pb.FilesItem.PolicyType.LIVE:
          return "live"
      # An unhandled value used to reach the `return` with the local never
      # assigned (UnboundLocalError); report the value explicitly instead.
      raise ValueError(f"Unknown file policy enum: {enum!r}")
  48. class InterfaceBase(abc.ABC):
  49. """Methods for sending run messages (Records) to the service.
  50. None of the methods may be called from an asyncio context other than
  51. deliver_async() or those with a `nowait=True` argument.
  52. """
  53. _drop: bool
  54. def __init__(self) -> None:
  55. self._drop = False
  @abc.abstractmethod
  async def deliver_async(
      self,
      record: pb.Record,
  ) -> MailboxHandle[pb.Result]:
      """Send a record from the asyncio thread and get a response handle.

      The synchronous publish and deliver methods on this class block, so
      they cannot be called in the asyncio thread. Rather than mirroring
      each of them with an async copy, this one general method sends any
      kind of record from the asyncio thread.

      Args:
          record: The record to send. Ownership passes to this method, so
              the caller must not use the record afterward.

      Returns:
          A handle to wait for a response to the record.
      """
  def publish_header(self) -> None:
      """Publish an empty header record."""
      self._publish_header(pb.HeaderRecord())
  @abc.abstractmethod
  def _publish_header(self, header: pb.HeaderRecord) -> None:
      """Abstract hook behind `publish_header`; subclasses must override it."""
      raise NotImplementedError
  def deliver_status(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty status request and return the response handle."""
      request = pb.StatusRequest()
      return self._deliver_status(request)
  @abc.abstractmethod
  def _deliver_status(
      self,
      status: pb.StatusRequest,
  ) -> MailboxHandle[pb.Result]:
      """Abstract hook behind `deliver_status`; subclasses must override it."""
      raise NotImplementedError
  def _make_config(
      self,
      data: dict | None = None,
      key: tuple[str, ...] | str | None = None,
      val: Any | None = None,
      obj: pb.ConfigRecord | None = None,
  ) -> pb.ConfigRecord:
      """Fill a ConfigRecord with config updates.

      Args:
          data: Optional mapping; each item becomes one update entry keyed
              by the mapping key.
          key: Optional single key to update. A tuple is stored as a nested
              key path, anything else as a plain key.
          val: Value stored for `key`.
          obj: Existing record to append to; a new one is created when it
              is falsy.

      Returns:
          The record that was populated (`obj` when it was supplied).
      """
      record = obj or pb.ConfigRecord()

      for name, raw in (data or {}).items():
          entry = record.update.add()
          entry.key = name
          friendly = json_friendly(raw)[0]
          entry.value_json = json_dumps_safer(friendly)

      if not key:
          return record

      entry = record.update.add()
      if not isinstance(key, tuple):
          entry.key = key
      else:
          for part in key:
              entry.nested_key.append(part)
      friendly = json_friendly(val)[0]
      entry.value_json = json_dumps_safer(friendly)
      return record
  def _make_run(self, run: Run) -> pb.RunRecord:  # noqa: C901
      """Build a RunRecord from a run's settings and in-memory state.

      Args:
          run: The run whose `_settings`, start time, starting step, config,
              telemetry and runtime are copied into the record.

      Returns:
          A new RunRecord with only the available fields set.
      """
      settings = run._settings
      record = pb.RunRecord()

      # (settings attribute, RunRecord field) pairs copied when not None.
      identity_fields = (
          ("entity", "entity"),
          ("project", "project"),
          ("run_group", "run_group"),
          ("run_job_type", "job_type"),
          ("run_id", "run_id"),
          ("run_name", "display_name"),
          ("run_notes", "notes"),
      )
      for attr, field in identity_fields:
          value = getattr(settings, attr)
          if value is not None:
              setattr(record, field, value)

      tags = settings.run_tags
      if tags is not None:
          record.tags.extend(tags)

      started = run._start_time
      if started is not None:
          record.start_time.FromMicroseconds(int(started * 1e6))
      first_step = run._starting_step
      if first_step is not None:
          record.starting_step = first_step

      remote_url = settings.git_remote_url
      if remote_url is not None:
          record.git.remote_url = remote_url
      commit = settings.git_commit
      if commit is not None:
          record.git.commit = commit
      sweep_id = settings.sweep_id
      if sweep_id is not None:
          record.sweep_id = sweep_id

      # These two are copied only when truthy, not merely non-None.
      if settings.host:
          record.host = settings.host
      if settings.resumed:
          record.resumed = settings.resumed

      # Both settings write the same branch_point; resume_from is applied
      # last, so it wins when both are set.
      for moment in (settings.fork_from, settings.resume_from):
          if moment:
              record.branch_point.run = moment.run
              record.branch_point.metric = moment.metric
              record.branch_point.value = moment.value

      if run._forked:
          record.forked = run._forked

      if run._config is not None:
          config_dict = run._config._as_dict()  # type: ignore
          self._make_config(data=config_dict, obj=record.config)

      if run._telemetry_obj:
          record.telemetry.MergeFrom(run._telemetry_obj)
      if run._start_runtime:
          record.runtime = run._start_runtime

      return record
  def publish_run(self, run: Run) -> None:
      """Convert `run` into a RunRecord and publish it."""
      self._publish_run(self._make_run(run))
  @abc.abstractmethod
  def _publish_run(self, run: pb.RunRecord) -> None:
      """Abstract hook behind `publish_run`; subclasses must override it."""
      raise NotImplementedError
  def publish_cancel(self, cancel_slot: str) -> None:
      """Publish a cancel request for the given cancel slot."""
      self._publish_cancel(pb.CancelRequest(cancel_slot=cancel_slot))
  @abc.abstractmethod
  def _publish_cancel(self, cancel: pb.CancelRequest) -> None:
      """Abstract hook behind `publish_cancel`; subclasses must override it."""
      raise NotImplementedError
  def publish_config(
      self,
      data: dict | None = None,
      key: tuple[str, ...] | str | None = None,
      val: Any | None = None,
  ) -> None:
      """Build a ConfigRecord from the arguments and publish it.

      Args:
          data: Optional mapping of config keys to values.
          key: Optional single key (or tuple key path) to update.
          val: Value stored for `key`.
      """
      self._publish_config(self._make_config(data=data, key=key, val=val))
  @abc.abstractmethod
  def _publish_config(self, cfg: pb.ConfigRecord) -> None:
      """Abstract hook behind `publish_config`; subclasses must override it."""
      raise NotImplementedError
  @abc.abstractmethod
  def _publish_metric(self, metric: pb.MetricRecord) -> None:
      """Abstract hook that takes a MetricRecord; subclasses must override it."""
      raise NotImplementedError
  def _make_summary_from_dict(self, summary_dict: dict) -> pb.SummaryRecord:
      """Build a SummaryRecord with one update entry per dict item.

      Args:
          summary_dict: Mapping of summary keys to values; each value is
              stored as its JSON dump.

      Returns:
          The populated SummaryRecord.
      """
      record = pb.SummaryRecord()
      for name, raw in summary_dict.items():
          entry = record.update.add()
          entry.key = name
          entry.value_json = json.dumps(raw)
      return record
  def _summary_encode(
      self,
      value: Any,
      path_from_root: str,
      run: Run,
  ) -> dict:
      """Normalize, compress, and encode sub-objects for backend storage.

      Args:
          value: Object to encode.
          path_from_root: Dot-separated path from the top-level summary to
              the current `value`.
          run: Run passed through to `val_to_json`.

      Returns:
          A new tree of dicts with large objects replaced with dictionaries
          with "_type" entries that say which type the original data was.
      """
      # Leaves are converted and possibly compressed; dicts are rebuilt
      # recursively so non-JSON-serializable objects never survive.
      if not isinstance(value, dict):
          friendly, _converted = json_friendly(
              val_to_json(run, path_from_root, value, namespace="summary")
          )
          # TODO(jhr): when the value was compressed it should also be
          # written out as h5 (write_h5); the flag is currently unused.
          encoded, _compressed = maybe_compress_summary(
              friendly, get_h5_typename(value)
          )
          return encoded

      return {
          child_key: self._summary_encode(
              child,
              path_from_root + "." + child_key,
              run=run,
          )
          for child_key, child in value.items()
      }
  def _make_summary(
      self,
      summary_record: sr.SummaryRecord,
      run: Run,
  ) -> pb.SummaryRecord:
      """Convert an in-memory summary record into its protobuf form.

      Args:
          summary_record: Record whose `update` and `remove` items are
              copied; each item's `key` is a non-empty sequence of parts.
          run: Run passed through to `_summary_encode`.

      Returns:
          A SummaryRecord proto with matching update and remove entries.
      """

      def _assign_key(target: Any, parts: Any) -> None:
          # One part is stored as a plain key, several as a nested key path.
          depth = len(parts)
          assert depth > 0
          if depth > 1:
              target.nested_key.extend(parts)
          else:
              target.key = parts[0]

      proto = pb.SummaryRecord()

      for update in summary_record.update:
          proto_item = proto.update.add()
          _assign_key(proto_item, update.key)
          encoded = self._summary_encode(
              update.value,
              ".".join(update.key),
              run=run,
          )
          encoded, _ = json_friendly(encoded)  # type: ignore
          proto_item.value_json = json.dumps(encoded, cls=WandBJSONEncoderOld)

      for removal in summary_record.remove:
          _assign_key(proto.remove.add(), removal.key)

      return proto
  def publish_summary(
      self,
      run: Run,
      summary_record: sr.SummaryRecord,
  ) -> None:
      """Convert a summary record to its protobuf form and publish it."""
      self._publish_summary(self._make_summary(summary_record, run=run))
  @abc.abstractmethod
  def _publish_summary(self, summary: pb.SummaryRecord) -> None:
      """Abstract hook behind `publish_summary`; subclasses must override it."""
      raise NotImplementedError
  def _make_files(self, files_dict: FilesDict) -> pb.FilesRecord:
      """Build a FilesRecord from the (path, policy) pairs in `files_dict`.

      Args:
          files_dict: Mapping whose "files" entry is an iterable of
              (path, policy name) pairs.

      Returns:
          A FilesRecord with one item per pair.
      """
      record = pb.FilesRecord()
      for file_path, policy_name in files_dict["files"]:
          item = record.files.add()
          item.path = file_path
          item.policy = file_policy_to_enum(policy_name)
      return record
  def publish_files(self, files_dict: FilesDict) -> None:
      """Convert `files_dict` into a FilesRecord and publish it."""
      self._publish_files(self._make_files(files_dict))
  @abc.abstractmethod
  def _publish_files(self, files: pb.FilesRecord) -> None:
      """Abstract hook behind `publish_files`; subclasses must override it."""
      raise NotImplementedError
  def publish_python_packages(self, working_set) -> None:
      """Publish the name and version of every package in `working_set`.

      Args:
          working_set: Iterable of package objects exposing `key` (used as
              the package name) and `version`.
      """
      request = pb.PythonPackagesRequest()
      for dist in working_set:
          request.package.add(name=dist.key, version=dist.version)
      self._publish_python_packages(request)
  @abc.abstractmethod
  def _publish_python_packages(
      self, python_packages: pb.PythonPackagesRequest
  ) -> None:
      """Abstract hook behind `publish_python_packages`; must be overridden."""
      raise NotImplementedError
  def _make_artifact(self, artifact: Artifact) -> pb.ArtifactRecord:
      """Build an ArtifactRecord, including its manifest, from an artifact.

      Args:
          artifact: The artifact to convert.

      Returns:
          A new ArtifactRecord. Optional fields (distributed id,
          description, metadata, base id, TTL) are set only when truthy.
      """
      record = pb.ArtifactRecord()

      # Fields that are always copied.
      record.type = artifact.type
      record.name = artifact.name
      record.client_id = artifact._client_id
      record.sequence_client_id = artifact._sequence_client_id
      record.digest = artifact.digest

      # Fields that are copied only when they carry a truthy value.
      distributed_id = artifact.distributed_id
      if distributed_id:
          record.distributed_id = distributed_id
      description = artifact.description
      if description:
          record.description = description
      metadata = artifact.metadata
      if metadata:
          record.metadata = json.dumps(json_friendly_val(metadata))
      base_id = artifact._base_id
      if base_id:
          record.base_id = base_id
      ttl_seconds = artifact._ttl_duration_seconds_to_gql()
      if ttl_seconds:
          record.ttl_duration_seconds = ttl_seconds

      record.incremental_beta1 = artifact.incremental
      self._make_artifact_manifest(artifact.manifest, obj=record.manifest)
      return record
  def _make_artifact_manifest(
      self,
      artifact_manifest: ArtifactManifest,
      obj: pb.ArtifactManifest | None = None,
  ) -> pb.ArtifactManifest:
      """Populate an ArtifactManifest proto from an artifact manifest.

      Manifests with more than MANIFEST_FILE_SIZE_THRESHOLD entries are
      written to a file and only the file path is stored on the proto;
      smaller ones have their storage policy config and entries inlined.

      Args:
          artifact_manifest: The manifest to convert.
          obj: Existing proto to fill in; a new one is created when falsy.

      Returns:
          The populated proto (`obj` when it was supplied).
      """
      proto_manifest = obj or pb.ArtifactManifest()
      proto_manifest.version = artifact_manifest.version()
      proto_manifest.storage_policy = artifact_manifest.storage_policy.name()

      # Very large manifests need to be written to file to avoid protobuf size limits.
      if len(artifact_manifest) > MANIFEST_FILE_SIZE_THRESHOLD:
          path = self._write_artifact_manifest_file(artifact_manifest)
          proto_manifest.manifest_file_path = path
          return proto_manifest

      # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
      # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
      # The creation logic is in artifacts/_factories.py make_storage_policy
      #
      # The fallback must wrap the config itself: the previous
      # `config().items() or {}.items()` applied `or` only after `.items()`
      # had already been called, so a None config raised AttributeError.
      policy_config = artifact_manifest.storage_policy.config() or {}
      for k, v in policy_config.items():
          cfg = proto_manifest.storage_policy_config.add()
          cfg.key = k
          # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
          cfg.value_json = json.dumps(v)

      # Entries are emitted in path order.
      for entry in sorted(artifact_manifest.entries.values(), key=lambda e: e.path):
          proto_entry = proto_manifest.contents.add()
          proto_entry.path = entry.path
          proto_entry.digest = entry.digest
          if entry.size:
              proto_entry.size = entry.size
          if entry.birth_artifact_id:
              proto_entry.birth_artifact_id = entry.birth_artifact_id
          if entry.ref:
              proto_entry.ref = entry.ref
          if entry.local_path:
              proto_entry.local_path = entry.local_path
          proto_entry.skip_cache = entry.skip_cache
          for k, v in entry.extra.items():
              proto_extra = proto_entry.extra.add()
              proto_extra.key = k
              proto_extra.value_json = json.dumps(v)
      return proto_manifest
  def _write_artifact_manifest_file(self, manifest: ArtifactManifest) -> str:
      """Write the manifest entries to a gzipped JSON-lines file.

      The file is created under `<staging dir>/artifact_manifests` with a
      name built from the current time and a random token.

      Args:
          manifest: The manifest whose entries are written, one JSON
              document per line.

      Returns:
          The path of the written file, as a string.
      """
      from wandb.sdk.artifacts.staging import get_staging_dir

      manifest_dir = Path(get_staging_dir()) / "artifact_manifests"
      manifest_dir.mkdir(parents=True, exist_ok=True)
      # It would be simpler to use `manifest.to_json()`, but that gets very slow for
      # large manifests since it encodes the whole thing as a single JSON object.
      filename = f"{time.time()}_{token_hex(8)}.manifest_contents.jl.gz"
      manifest_file_path = manifest_dir / filename
      # Pin the text encoding: without it, text-mode gzip.open falls back to
      # the locale's preferred encoding, so the bytes written for any
      # non-ASCII character would depend on the machine.
      with gzip.open(
          manifest_file_path,
          mode="wt",
          compresslevel=1,
          encoding="utf-8",
      ) as f:
          for entry in manifest.entries.values():
              f.write(f"{json.dumps(entry.to_json())}\n")
      return str(manifest_file_path)
  def deliver_link_artifact(
      self,
      artifact: Artifact,
      portfolio_name: str,
      aliases: Iterable[str],
      entity: str | None = None,
      project: str | None = None,
      organization: str | None = None,
  ) -> MailboxHandle[pb.Result]:
      """Deliver a request to link an artifact into a portfolio.

      Args:
          artifact: Artifact to link. A draft is identified by its client
              id, any other artifact by its server id.
          portfolio_name: Name of the target portfolio.
          aliases: Aliases to attach to the link.
          entity: Portfolio entity; sent as "" when not given.
          project: Portfolio project; sent as "" when not given.
          organization: Portfolio organization; sent as "" when not given.

      Returns:
          A handle to wait for the response.
      """
      request = pb.LinkArtifactRequest()
      if not artifact.is_draft():
          request.server_id = artifact.id or ""
      else:
          request.client_id = artifact._client_id

      request.portfolio_name = portfolio_name
      request.portfolio_entity = entity or ""
      request.portfolio_organization = organization or ""
      request.portfolio_project = project or ""
      request.portfolio_aliases.extend(aliases)
      return self._deliver_link_artifact(request)
  @abc.abstractmethod
  def _deliver_link_artifact(
      self, link_artifact: pb.LinkArtifactRequest
  ) -> MailboxHandle[pb.Result]:
      """Abstract hook behind `deliver_link_artifact`; must be overridden."""
      raise NotImplementedError
  394. @staticmethod
  395. def _make_partial_source_str(
  396. source: Any, job_info: dict[str, Any], metadata: dict[str, Any]
  397. ) -> str:
  398. """Construct use_artifact.partial.source_info.source as str."""
  399. source_type = job_info.get("source_type", "").strip()
  400. if source_type == "artifact":
  401. info_source = job_info.get("source", {})
  402. source.artifact.artifact = info_source.get("artifact", "")
  403. source.artifact.entrypoint.extend(info_source.get("entrypoint", []))
  404. source.artifact.notebook = info_source.get("notebook", False)
  405. build_context = info_source.get("build_context")
  406. if build_context:
  407. source.artifact.build_context = build_context
  408. dockerfile = info_source.get("dockerfile")
  409. if dockerfile:
  410. source.artifact.dockerfile = dockerfile
  411. elif source_type == "repo":
  412. source.git.git_info.remote = metadata.get("git", {}).get("remote", "")
  413. source.git.git_info.commit = metadata.get("git", {}).get("commit", "")
  414. source.git.entrypoint.extend(metadata.get("entrypoint", []))
  415. source.git.notebook = metadata.get("notebook", False)
  416. build_context = metadata.get("build_context")
  417. if build_context:
  418. source.git.build_context = build_context
  419. dockerfile = metadata.get("dockerfile")
  420. if dockerfile:
  421. source.git.dockerfile = dockerfile
  422. elif source_type == "image":
  423. source.image.image = metadata.get("docker", "")
  424. else:
  425. raise ValueError("Invalid source type")
  426. source_str: str = source.SerializeToString()
  427. return source_str
  def _make_proto_use_artifact(
      self,
      use_artifact: pb.UseArtifactRecord,
      job_name: str,
      job_info: dict[str, Any],
      metadata: dict[str, Any],
  ) -> pb.UseArtifactRecord:
      """Fill the `partial` section of a UseArtifactRecord.

      Args:
          use_artifact: Record to update in place.
          job_name: Name stored as the partial job name.
          job_info: Job description supplying `_version`, `source_type`,
              `runtime` and the source details.
          metadata: Artifact metadata forwarded to the source builder.

      Returns:
          The same `use_artifact` record.
      """
      partial = use_artifact.partial
      partial.job_name = job_name

      source_info = partial.source_info
      source_info._version = job_info.get("_version", "")
      source_info.source_type = job_info.get("source_type", "")
      source_info.runtime = job_info.get("runtime", "")

      # The helper fills `source` in place and returns its serialized form,
      # which is then parsed back into that same message.
      serialized = self._make_partial_source_str(
          source=source_info.source,
          job_info=job_info,
          metadata=metadata,
      )
      source_info.source.ParseFromString(serialized)  # type: ignore[arg-type]
      return use_artifact
  def publish_use_artifact(
      self,
      artifact: Artifact,
  ) -> None:
      """Publish a record saying that `artifact` is being used.

      When the artifact metadata contains "_partial", the job info is first
      downloaded from the artifact's "wandb-job.json" entry and added to
      the record. If the download or the record construction fails, a
      warning is emitted and nothing is published.

      Args:
          artifact: The artifact being used; it must already have an id.
      """
      assert artifact.id is not None, "Artifact must have an id"
      record = pb.UseArtifactRecord(
          id=artifact.id,
          type=artifact.type,
          name=artifact.name,
      )

      if "_partial" not in artifact.metadata:
          self._publish_use_artifact(record)
          return

      # TODO(gst): move to internal process
      # Download source info from logged partial job artifact
      try:
          job_file = artifact.get_entry("wandb-job.json").download()
          with open(job_file) as f:
              job_info = json.load(f)
      except Exception as e:
          message = (
              f"Failed to download partial job info from artifact {artifact}, : {e}"
          )
          logger.warning(message)
          termwarn(message)
          return

      try:
          record = self._make_proto_use_artifact(
              use_artifact=record,
              job_name=artifact.name,
              job_info=job_info,
              metadata=artifact.metadata,
          )
      except Exception as e:
          message = f"Failed to construct use artifact proto: {e}"
          logger.warning(message)
          termwarn(message)
          return

      self._publish_use_artifact(record)
  @abc.abstractmethod
  def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None:
      """Abstract hook behind `publish_use_artifact`; must be overridden."""
      raise NotImplementedError
  def deliver_artifact(
      self,
      run: Run,
      artifact: Artifact,
      aliases: Iterable[str],
      tags: Iterable[str] | None = None,
      history_step: int | None = None,
      is_user_created: bool = False,
      use_after_commit: bool = False,
      finalize: bool = True,
  ) -> MailboxHandle[pb.Result]:
      """Deliver a request to log an artifact and return the response handle.

      Args:
          run: Run supplying the run id, project and entity for the record.
          artifact: The artifact to log.
          aliases: Aliases for the artifact; a falsy value adds none.
          tags: Tags for the artifact; a falsy value adds none.
          history_step: History step stored on the request when not None.
          is_user_created: Copied to the record's `user_created` field.
          use_after_commit: Copied to the record's `use_after_commit` field.
          finalize: Copied to the record's `finalize` field.

      Returns:
          A handle to wait for the response.
      """
      from wandb.sdk.artifacts.staging import get_staging_dir

      run_record = self._make_run(run)
      artifact_record = self._make_artifact(artifact)

      # Tie the artifact to the run it is logged from.
      artifact_record.run_id = run_record.run_id
      artifact_record.project = run_record.project
      artifact_record.entity = run_record.entity

      artifact_record.user_created = is_user_created
      artifact_record.use_after_commit = use_after_commit
      artifact_record.finalize = finalize
      artifact_record.aliases.extend(aliases or [])
      artifact_record.tags.extend(tags or [])

      request = pb.LogArtifactRequest()
      request.artifact.CopyFrom(artifact_record)
      if history_step is not None:
          request.history_step = history_step
      request.staging_dir = get_staging_dir()
      return self._deliver_artifact(request)
  @abc.abstractmethod
  def _deliver_artifact(
      self,
      log_artifact: pb.LogArtifactRequest,
  ) -> MailboxHandle[pb.Result]:
      """Abstract hook behind `deliver_artifact`; subclasses must override it."""
      raise NotImplementedError
  def deliver_download_artifact(
      self,
      artifact_id: str,
      download_root: str,
      allow_missing_references: bool,
      skip_cache: bool,
      path_prefix: str | None,
  ) -> MailboxHandle[pb.Result]:
      """Deliver a request to download an artifact.

      Args:
          artifact_id: Id of the artifact to download.
          download_root: Copied to the request's `download_root` field.
          allow_missing_references: Copied to the request unchanged.
          skip_cache: Copied to the request unchanged.
          path_prefix: Copied to the request; sent as "" when falsy.

      Returns:
          A handle to wait for the response.
      """
      request = pb.DownloadArtifactRequest()
      request.artifact_id = artifact_id
      request.download_root = download_root
      request.allow_missing_references = allow_missing_references
      request.skip_cache = skip_cache
      request.path_prefix = path_prefix or ""
      return self._deliver_download_artifact(request)
  @abc.abstractmethod
  def _deliver_download_artifact(
      self, download_artifact: pb.DownloadArtifactRequest
  ) -> MailboxHandle[pb.Result]:
      """Abstract hook behind `deliver_download_artifact`; must be overridden."""
      raise NotImplementedError
  def publish_artifact(
      self,
      run: Run,
      artifact: Artifact,
      aliases: Iterable[str],
      tags: Iterable[str] | None = None,
      is_user_created: bool = False,
      use_after_commit: bool = False,
      finalize: bool = True,
  ) -> None:
      """Publish an artifact record for `artifact`, tied to `run`.

      Args:
          run: Run supplying the run id, project and entity for the record.
          artifact: The artifact to publish.
          aliases: Aliases for the artifact; a falsy value adds none.
          tags: Tags for the artifact; a falsy value adds none.
          is_user_created: Copied to the record's `user_created` field.
          use_after_commit: Copied to the record's `use_after_commit` field.
          finalize: Copied to the record's `finalize` field.
      """
      run_record = self._make_run(run)
      artifact_record = self._make_artifact(artifact)

      # Tie the artifact to the run it is published from.
      artifact_record.run_id = run_record.run_id
      artifact_record.project = run_record.project
      artifact_record.entity = run_record.entity

      artifact_record.user_created = is_user_created
      artifact_record.use_after_commit = use_after_commit
      artifact_record.finalize = finalize
      artifact_record.aliases.extend(aliases or [])
      artifact_record.tags.extend(tags or [])
      self._publish_artifact(artifact_record)
  @abc.abstractmethod
  def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
      """Abstract hook behind `publish_artifact`; subclasses must override it."""
      raise NotImplementedError
  def publish_tbdata(self, log_dir: str, save: bool, root_logdir: str = "") -> None:
      """Publish a TBRecord built from the given values.

      Args:
          log_dir: Copied to the record's `log_dir` field.
          save: Copied to the record's `save` field.
          root_logdir: Copied to the record's `root_dir` field.
      """
      record = pb.TBRecord()
      record.log_dir = log_dir
      record.save = save
      record.root_dir = root_logdir
      self._publish_tbdata(record)
  @abc.abstractmethod
  def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None:
      """Abstract hook behind `publish_tbdata`; subclasses must override it."""
      raise NotImplementedError
  @abc.abstractmethod
  def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
      """Abstract hook that takes a TelemetryRecord; must be overridden."""
      raise NotImplementedError
  def publish_environment(self, environment: pb.EnvironmentRecord) -> None:
      """Publish the given environment record unchanged."""
      self._publish_environment(environment)
  @abc.abstractmethod
  def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
      """Abstract hook behind `publish_environment`; must be overridden."""
      raise NotImplementedError
  def publish_partial_history(
      self,
      run: Run,
      data: dict,
      user_step: int,
      step: int | None = None,
      flush: bool | None = None,
      publish_step: bool = True,
  ) -> None:
      """Publish a partial history request for `data`.

      Args:
          run: Run passed to `history_dict_to_json`.
          data: History values to publish.
          user_step: Step passed to `history_dict_to_json`.
          step: Step number stored on the request when `publish_step` is
              true and the value is not None.
          flush: Flush flag stored on the request when not None.
          publish_step: Whether `step` may be stored on the request.
      """
      payload = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True)
      payload.pop("_step", None)

      # add timestamp to the history request, if not already present
      # the timestamp might come from the tensorboard log logic
      if "_timestamp" not in payload:
          payload["_timestamp"] = time.time()

      request = pb.PartialHistoryRequest()
      for name, raw in payload.items():
          entry = request.item.add()
          entry.key = name
          entry.value_json = json_dumps_safer_history(raw)

      if step is not None and publish_step:
          request.step.num = step
      if flush is not None:
          request.action.flush = flush
      self._publish_partial_history(request)
  @abc.abstractmethod
  def _publish_partial_history(self, history: pb.PartialHistoryRequest) -> None:
      """Abstract hook behind `publish_partial_history`; must be overridden."""
      raise NotImplementedError
  def publish_history(
      self,
      run: Run,
      data: dict,
      step: int | None = None,
      publish_step: bool = True,
  ) -> None:
      """Publish a history record for `data`.

      Args:
          run: Run passed to `history_dict_to_json`.
          data: History values to publish.
          step: Step number; must not be None when `publish_step` is true.
          publish_step: Whether to store `step` on the record.
      """
      payload = history_dict_to_json(run, data, step=step)
      record = pb.HistoryRecord()

      if publish_step:
          assert step is not None
          record.step.num = step

      # The step is carried by `record.step`, never as a regular item.
      payload.pop("_step", None)
      for name, raw in payload.items():
          entry = record.item.add()
          entry.key = name
          entry.value_json = json_dumps_safer_history(raw)
      self._publish_history(record)
  @abc.abstractmethod
  def _publish_history(self, history: pb.HistoryRecord) -> None:
      """Abstract hook behind `publish_history`; subclasses must override it."""
      raise NotImplementedError
  def publish_preempting(self) -> None:
      """Publish an empty run-preempting record."""
      self._publish_preempting(pb.RunPreemptingRecord())
  @abc.abstractmethod
  def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None:
      """Abstract hook behind `publish_preempting`; must be overridden."""
      raise NotImplementedError
  def publish_output(
      self,
      name: str,
      data: str,
      *,
      nowait: bool = False,
  ) -> None:
      """Publish one piece of console output as an OutputRecord.

      Args:
          name: Stream the data came from: "stdout" or "stderr". Any other
              value produces a warning and nothing is published.
          data: The output text, stored in the record's `line` field.
          nowait: Forwarded to `_publish_output`.
      """
      if name == "stdout":
          otype = pb.OutputRecord.OutputType.STDOUT
      elif name == "stderr":
          otype = pb.OutputRecord.OutputType.STDERR
      else:
          # TODO(jhr): throw error?
          termwarn("unknown type")
          # There is no output type to attach, so stop here. Falling through
          # used to crash with UnboundLocalError on `otype` right after the
          # warning.
          return
      o = pb.OutputRecord(output_type=otype, line=data)
      o.timestamp.GetCurrentTime()
      self._publish_output(o, nowait=nowait)
  @abc.abstractmethod
  def _publish_output(self, outdata: pb.OutputRecord, *, nowait: bool) -> None:
      """Abstract hook behind `publish_output`; subclasses must override it."""
      raise NotImplementedError
  def publish_output_raw(
      self,
      name: str,
      data: str,
      *,
      nowait: bool = False,
  ) -> None:
      """Publish one piece of console output as an OutputRawRecord.

      Args:
          name: Stream the data came from: "stdout" or "stderr". Any other
              value produces a warning and nothing is published.
          data: The output text, stored in the record's `line` field.
          nowait: Forwarded to `_publish_output_raw`.
      """
      if name == "stdout":
          otype = pb.OutputRawRecord.OutputType.STDOUT
      elif name == "stderr":
          otype = pb.OutputRawRecord.OutputType.STDERR
      else:
          # TODO(jhr): throw error?
          termwarn("unknown type")
          # There is no output type to attach, so stop here. Falling through
          # used to crash with UnboundLocalError on `otype` right after the
          # warning.
          return
      o = pb.OutputRawRecord(output_type=otype, line=data)
      o.timestamp.GetCurrentTime()
      self._publish_output_raw(o, nowait=nowait)
  @abc.abstractmethod
  def _publish_output_raw(
      self,
      outdata: pb.OutputRawRecord,
      *,
      nowait: bool,
  ) -> None:
      """Abstract hook behind `publish_output_raw`; must be overridden."""
      raise NotImplementedError
  def publish_pause(self) -> None:
      """Publish an empty pause request."""
      self._publish_pause(pb.PauseRequest())
  @abc.abstractmethod
  def _publish_pause(self, pause: pb.PauseRequest) -> None:
      """Abstract hook behind `publish_pause`; subclasses must override it."""
      raise NotImplementedError
  def publish_resume(self) -> None:
      """Publish an empty resume request."""
      self._publish_resume(pb.ResumeRequest())
  @abc.abstractmethod
  def _publish_resume(self, resume: pb.ResumeRequest) -> None:
      """Abstract hook behind `publish_resume`; subclasses must override it."""
      raise NotImplementedError
  def publish_alert(
      self, title: str, text: str, level: str, wait_duration: int
  ) -> None:
      """Publish an AlertRecord built from the given values.

      Args:
          title: Copied to the record's `title` field.
          text: Copied to the record's `text` field.
          level: Copied to the record's `level` field.
          wait_duration: Copied to the record's `wait_duration` field.
      """
      alert = pb.AlertRecord()
      alert.title = title
      alert.text = text
      alert.level = level
      alert.wait_duration = wait_duration
      self._publish_alert(alert)
  713. @abc.abstractmethod
  714. def _publish_alert(self, alert: pb.AlertRecord) -> None:
  715. raise NotImplementedError
  716. def _make_exit(self, exit_code: int | None) -> pb.RunExitRecord:
  717. exit = pb.RunExitRecord()
  718. if exit_code is not None:
  719. exit.exit_code = exit_code
  720. return exit
  721. def publish_exit(self, exit_code: int | None) -> None:
  722. exit_data = self._make_exit(exit_code)
  723. self._publish_exit(exit_data)
  724. @abc.abstractmethod
  725. def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
  726. raise NotImplementedError
  727. def publish_keepalive(self) -> None:
  728. keepalive = pb.KeepaliveRequest()
  729. self._publish_keepalive(keepalive)
  730. @abc.abstractmethod
  731. def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
  732. raise NotImplementedError
  733. def publish_job_input(
  734. self,
  735. include_paths: list[list[str]],
  736. exclude_paths: list[list[str]],
  737. input_schema: dict | None,
  738. run_config: bool = False,
  739. file_path: str = "",
  740. ):
  741. """Publishes a request to add inputs to the job.
  742. If run_config is True, the wandb.config will be added as a job input.
  743. If file_path is provided, the file at file_path will be added as a job
  744. input.
  745. The paths provided as arguments are sequences of dictionary keys that
  746. specify a path within the wandb.config. If a path is included, the
  747. corresponding field will be treated as a job input. If a path is
  748. excluded, the corresponding field will not be treated as a job input.
  749. Args:
  750. include_paths: paths within config to include as job inputs.
  751. exclude_paths: paths within config to exclude as job inputs.
  752. input_schema: A JSON Schema describing which attributes will be
  753. editable from the Launch drawer.
  754. run_config: bool indicating whether wandb.config is the input source.
  755. file_path: path to file to include as a job input.
  756. """
  757. if run_config and file_path:
  758. raise ValueError(
  759. "run_config and file_path are mutually exclusive arguments."
  760. )
  761. request = pb.JobInputRequest()
  762. include_records = [pb.JobInputPath(path=path) for path in include_paths]
  763. exclude_records = [pb.JobInputPath(path=path) for path in exclude_paths]
  764. request.include_paths.extend(include_records)
  765. request.exclude_paths.extend(exclude_records)
  766. source = pb.JobInputSource(
  767. run_config=pb.JobInputSource.RunConfigSource(),
  768. )
  769. if run_config:
  770. source.run_config.CopyFrom(pb.JobInputSource.RunConfigSource())
  771. else:
  772. source.file.CopyFrom(
  773. pb.JobInputSource.ConfigFileSource(path=file_path),
  774. )
  775. request.input_source.CopyFrom(source)
  776. if input_schema:
  777. request.input_schema = json_dumps_safer(input_schema)
  778. return self._publish_job_input(request)
  779. @abc.abstractmethod
  780. def _publish_job_input(
  781. self, request: pb.JobInputRequest
  782. ) -> MailboxHandle[pb.Result]:
  783. raise NotImplementedError
  784. def publish_probe_system_info(self) -> None:
  785. probe_system_info = pb.ProbeSystemInfoRequest()
  786. return self._publish_probe_system_info(probe_system_info)
  787. @abc.abstractmethod
  788. def _publish_probe_system_info(
  789. self, probe_system_info: pb.ProbeSystemInfoRequest
  790. ) -> None:
  791. raise NotImplementedError
  792. def join(self) -> None:
  793. # Drop indicates that the internal process has already been shutdown
  794. if self._drop:
  795. return
  796. handle = self._deliver_shutdown()
  797. try:
  798. handle.wait_or(timeout=30)
  799. except TimeoutError:
  800. # This can happen if the server fails to respond due to a bug
  801. # or due to being very busy.
  802. logger.warning("timed out communicating shutdown")
  803. except HandleAbandonedError:
  804. # This can happen if the connection to the server is closed
  805. # before a response is read.
  806. logger.warning("handle abandoned while communicating shutdown")
  807. @abc.abstractmethod
  808. def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
  809. raise NotImplementedError
  810. def deliver_run(self, run: Run) -> MailboxHandle[pb.Result]:
  811. run_record = self._make_run(run)
  812. return self._deliver_run(run_record)
  813. def deliver_finish_sync(
  814. self,
  815. ) -> MailboxHandle[pb.Result]:
  816. sync = pb.SyncFinishRequest()
  817. return self._deliver_finish_sync(sync)
  818. @abc.abstractmethod
  819. def _deliver_finish_sync(
  820. self, sync: pb.SyncFinishRequest
  821. ) -> MailboxHandle[pb.Result]:
  822. raise NotImplementedError
  823. @abc.abstractmethod
  824. def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
  825. raise NotImplementedError
  826. def deliver_run_start(self, run: Run) -> MailboxHandle[pb.Result]:
  827. run_start = pb.RunStartRequest(run=self._make_run(run))
  828. return self._deliver_run_start(run_start)
  829. @abc.abstractmethod
  830. def _deliver_run_start(
  831. self, run_start: pb.RunStartRequest
  832. ) -> MailboxHandle[pb.Result]:
  833. raise NotImplementedError
  834. def deliver_attach(self, attach_id: str) -> MailboxHandle[pb.Result]:
  835. attach = pb.AttachRequest(attach_id=attach_id)
  836. return self._deliver_attach(attach)
  837. @abc.abstractmethod
  838. def _deliver_attach(
  839. self,
  840. status: pb.AttachRequest,
  841. ) -> MailboxHandle[pb.Result]:
  842. raise NotImplementedError
  843. def deliver_stop_status(self) -> MailboxHandle[pb.Result]:
  844. status = pb.StopStatusRequest()
  845. return self._deliver_stop_status(status)
  846. @abc.abstractmethod
  847. def _deliver_stop_status(
  848. self,
  849. status: pb.StopStatusRequest,
  850. ) -> MailboxHandle[pb.Result]:
  851. raise NotImplementedError
  852. def deliver_network_status(self) -> MailboxHandle[pb.Result]:
  853. status = pb.NetworkStatusRequest()
  854. return self._deliver_network_status(status)
  855. @abc.abstractmethod
  856. def _deliver_network_status(
  857. self,
  858. status: pb.NetworkStatusRequest,
  859. ) -> MailboxHandle[pb.Result]:
  860. raise NotImplementedError
  861. def deliver_internal_messages(self) -> MailboxHandle[pb.Result]:
  862. internal_message = pb.InternalMessagesRequest()
  863. return self._deliver_internal_messages(internal_message)
  864. @abc.abstractmethod
  865. def _deliver_internal_messages(
  866. self, internal_message: pb.InternalMessagesRequest
  867. ) -> MailboxHandle[pb.Result]:
  868. raise NotImplementedError
  869. def deliver_get_summary(self) -> MailboxHandle[pb.Result]:
  870. get_summary = pb.GetSummaryRequest()
  871. return self._deliver_get_summary(get_summary)
  872. @abc.abstractmethod
  873. def _deliver_get_summary(
  874. self,
  875. get_summary: pb.GetSummaryRequest,
  876. ) -> MailboxHandle[pb.Result]:
  877. raise NotImplementedError
  878. def deliver_get_system_metrics(self) -> MailboxHandle[pb.Result]:
  879. get_system_metrics = pb.GetSystemMetricsRequest()
  880. return self._deliver_get_system_metrics(get_system_metrics)
  881. @abc.abstractmethod
  882. def _deliver_get_system_metrics(
  883. self, get_summary: pb.GetSystemMetricsRequest
  884. ) -> MailboxHandle[pb.Result]:
  885. raise NotImplementedError
  886. def deliver_exit(self, exit_code: int | None) -> MailboxHandle[pb.Result]:
  887. exit_data = self._make_exit(exit_code)
  888. return self._deliver_exit(exit_data)
  889. @abc.abstractmethod
  890. def _deliver_exit(
  891. self,
  892. exit_data: pb.RunExitRecord,
  893. ) -> MailboxHandle[pb.Result]:
  894. raise NotImplementedError
  895. def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
  896. poll_exit = pb.PollExitRequest()
  897. return self._deliver_poll_exit(poll_exit)
  898. @abc.abstractmethod
  899. def _deliver_poll_exit(
  900. self,
  901. poll_exit: pb.PollExitRequest,
  902. ) -> MailboxHandle[pb.Result]:
  903. raise NotImplementedError
  904. def deliver_finish_without_exit(self) -> MailboxHandle[pb.Result]:
  905. run_finish_without_exit = pb.RunFinishWithoutExitRequest()
  906. return self._deliver_finish_without_exit(run_finish_without_exit)
  907. @abc.abstractmethod
  908. def _deliver_finish_without_exit(
  909. self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
  910. ) -> MailboxHandle[pb.Result]:
  911. raise NotImplementedError
  912. def deliver_request_sampled_history(self) -> MailboxHandle[pb.Result]:
  913. sampled_history = pb.SampledHistoryRequest()
  914. return self._deliver_request_sampled_history(sampled_history)
  915. @abc.abstractmethod
  916. def _deliver_request_sampled_history(
  917. self, sampled_history: pb.SampledHistoryRequest
  918. ) -> MailboxHandle[pb.Result]:
  919. raise NotImplementedError
  920. def deliver_request_run_status(self) -> MailboxHandle[pb.Result]:
  921. run_status = pb.RunStatusRequest()
  922. return self._deliver_request_run_status(run_status)
  923. @abc.abstractmethod
  924. def _deliver_request_run_status(
  925. self, run_status: pb.RunStatusRequest
  926. ) -> MailboxHandle[pb.Result]:
  927. raise NotImplementedError