config.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. """
  2. This is the top-level configuration module for the compiler, containing
  3. cross-cutting configuration options that affect all parts of the compiler
  4. stack.
  5. You may also be interested in the per-component configuration modules, which
  6. contain configuration options that affect only a specific part of the compiler:
  7. * :mod:`torch._dynamo.config`
  8. * :mod:`torch._inductor.config`
  9. * :mod:`torch._functorch.config`
  10. * :mod:`torch.fx.experimental.config`
  11. """
  12. import sys
  13. from typing import Optional
  14. from torch.utils._config_module import Config, install_config_module
  15. __all__ = [
  16. "job_id",
  17. ]
  18. # NB: Docblocks go UNDER variable definitions! Use spacing to make the
  19. # grouping clear.
  20. # FB-internal note: you do NOT have to specify this explicitly specify this if
  21. # you run on MAST, we will automatically default this to
  22. # mast:MAST_JOB_NAME:MAST_JOB_VERSION.
  23. job_id: Optional[str] = Config(
  24. env_name_default=["TORCH_COMPILE_JOB_ID", "TORCH_COMPILE_STICKY_PGO_KEY"],
  25. default=None,
  26. )
  27. """
  28. Semantically, this should be an identifier that uniquely identifies, e.g., a
  29. training job. You might have multiple attempts of the same job, e.g., if it was
  30. preempted or needed to be restarted, but each attempt should be running
  31. substantially the same workload with the same distributed topology. You can
  32. set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`.
  33. Operationally, this controls the effect of profile-guided optimization related
  34. persistent state. PGO state can affect how we perform compilation across
  35. multiple invocations of PyTorch, e.g., the first time you run your program we
  36. may compile twice as we discover what inputs are dynamic, and then PGO will
  37. save this state so subsequent invocations only need to compile once, because
  38. they remember it is dynamic. This profile information, however, is sensitive
  39. to what workload you are running, so we require you to tell us that two jobs
  40. are *related* (i.e., are the same workload) before we are willing to reuse
  41. this information. Notably, PGO does nothing (even if explicitly enabled)
  42. unless a valid ``job_id`` is available. In some situations, PyTorch can
  43. configured to automatically compute a ``job_id`` based on the environment it
  44. is running in.
  45. Profiles are always collected on a per rank basis, so different ranks may have
  46. different profiles. If you know your workload is truly SPMD, you can run with
  47. :data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get
  48. consistent profiles across all ranks.
  49. """
  50. pgo_extra_read_key: Optional[str] = Config(
  51. env_name_default="TORCH_COMPILE_STICKY_PGO_READ", default=None
  52. )
  53. pgo_extra_write_key: Optional[str] = Config(
  54. env_name_default="TORCH_COMPILE_STICKY_PGO_WRITE", default=None
  55. )
  56. """
  57. Additional read/write keys for PGO.
  58. Write key: Besides writing to the default local/remote PGO state, this also writes to the specified key.
  59. Read key: Besides reading from the default state, this also reads from the specified key (if written to before)
  60. and merges it with the default state.
  61. """
  62. cache_key_tag: str = Config(env_name_default="TORCH_COMPILE_CACHE_KEY_TAG", default="")
  63. """
  64. Tag to be included in the cache key generation for all torch compile caching.
  65. A common use case for such a tag is to break caches.
  66. """
  67. force_disable_caches: bool = Config(
  68. justknob="pytorch/remote_cache:force_disable_caches",
  69. env_name_force=[
  70. "TORCHINDUCTOR_FORCE_DISABLE_CACHES",
  71. "TORCH_COMPILE_FORCE_DISABLE_CACHES",
  72. ],
  73. default=False,
  74. )
  75. """
  76. Force disables all caching -- This will take precedence over and override any other caching flag
  77. """
  78. dynamic_sources: str = Config(
  79. env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default=""
  80. )
  81. """
  82. Comma delimited list of sources that should be marked as dynamic. Primarily useful for large
  83. models with graph breaks where you need intermediate tensors and ints to be marked dynamic.
  84. This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
  85. and force_parameter_static_shapes.
  86. """
  87. unbacked_sources: str = Config(
  88. env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default=""
  89. )
  90. """
  91. Comma delimited list of sources that should be marked as unbacked. Primarily useful for large
  92. models with graph breaks where you need intermediate tensors marked unbacked.
  93. This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
  94. and force_parameter_static_shapes.
  95. """
  96. # force a python GC before recording cudagraphs
  97. force_cudagraph_gc: bool = Config(env_name_default="TORCH_CUDAGRAPH_GC", default=False)
  98. """
  99. If True (the backward-compatible behavior) then gc.collect() before recording
  100. any cudagraph.
  101. """
# Hand this module object to install_config_module. The call comes last, after
# every option above has been assigned.
# NOTE(review): presumably this is what turns the Config(...) placeholders into
# live, overridable settings -- confirm against torch.utils._config_module.
install_config_module(sys.modules[__name__])