| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189 |
- # mypy: allow-untyped-defs
- import collections
- import importlib.machinery
- import io
- import linecache
- import os
- import pickletools
- import platform
- import types
- from collections import defaultdict, OrderedDict
- from collections.abc import Sequence
- from dataclasses import dataclass
- from enum import Enum
- from importlib.machinery import SourceFileLoader
- from pathlib import Path
- from typing import Any, Callable, cast, IO, Optional, Union
- import torch
- from torch.serialization import location_tag, normalize_storage_type
- from torch.types import FileLike, Storage
- from torch.utils.hooks import RemovableHandle
- from ._digraph import DiGraph
- from ._importlib import _normalize_path
- from ._mangling import demangle, is_mangled
- from ._package_pickler import create_pickler
- from ._stdlib import is_stdlib_module
- from .find_file_dependencies import find_files_source_depends_on
- from .glob_group import GlobGroup, GlobPattern
- from .importer import Importer, OrderedImporter, sys_importer
# Names exported by ``from <this module> import *``.
__all__ = [
    "PackagingErrorReason",
    "EmptyMatchError",
    "PackagingError",
    "PackageExporter",
]
# Module-level switch; nothing in the visible part of this file reads it.
# NOTE(review): presumably gates TorchScript serialization elsewhere -- confirm.
_gate_torchscript_serialization = True
# Signature shared by the hooks accepted by the ``register_*_hook`` methods:
# called with the exporter and a module name, returns nothing.
ActionHook = Callable[["PackageExporter", str], None]
class _ModuleProviderAction(Enum):
    """Enumerates the ways :class:`PackageExporter` can handle a module.

    The user-facing actions are described on :meth:`PackageExporter.extern`
    and its sibling pattern methods.
    """

    INTERN = 1
    EXTERN = 2
    MOCK = 3
    DENY = 4
    # Mocking a module makes PackageExporter emit a `_mock` module holding the
    # stub implementation. Re-packaging code can surface the `_mock` module of
    # the original package; that one is ignored so `_mock` is written once.
    REPACKAGED_MOCK_MODULE = 5
    # PackageImporter injects a fake `torch_package_importer` module so that
    # packaged code can reach its importer. It must never be re-exported.
    SKIP = 6
class PackagingErrorReason(Enum):
    """Why a dependency could not be packaged.

    Each member's value is the human-readable sentence that
    :class:`PackagingError` prints as the heading for that group of modules.
    """

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )

    def __repr__(self):
        # Compact form without the (long) message value.
        return f"<{type(self).__name__}.{self.name}>"
@dataclass
class _PatternInfo:
    """Bookkeeping that :class:`PackageExporter` keeps for one registered glob pattern."""

    # Action applied to every module that matches the pattern.
    action: _ModuleProviderAction
    # ``allow_empty`` exactly as the user supplied it with the pattern.
    allow_empty: bool
    # Becomes True once the pattern has matched a module during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        # A freshly registered pattern has not matched anything yet.
        self.was_matched = False
        self.action, self.allow_empty = action, allow_empty
class EmptyMatchError(Exception):
    """Raised when a pattern registered with ``allow_empty=False`` (e.g. a mock
    or extern) was never matched by any module during packaging.
    """
class PackagingError(Exception):
    """This exception is raised when there is an issue with exporting a package.
    ``PackageExporter`` will attempt to gather up all the errors and present
    them to you at once.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        """Build the report from every node of the graph that carries an error.

        Args:
            dependency_graph: Graph whose nodes map module names to attribute
                dicts; nodes with an ``"error"`` attribute are reported and an
                optional ``"error_context"`` attribute is printed with them.
                The graph is kept on ``self.dependency_graph``.
            debug: If ``True``, print one dependency path to each broken
                module; otherwise append a hint about enabling debug mode.
        """
        # Group errors by reason.
        broken: dict[PackagingErrorReason, list[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                # A module without a matching pattern cannot have an action.
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f" {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f" Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    # The suggested call quotes the bare module name so that it
                    # can be copied verbatim; backticks inside the string
                    # argument would produce a pattern that never matches.
                    message.write(
                        " Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We heavily "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        f'PackageExporter.extern("{module_name}")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    message.write(
                        f" A path to {module_name}: {' -> '.join(module_path)}\n"
                    )
        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )

        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())
class PackageExporter:
    """Exporters allow you to write packages of code, pickled Python data, and
    arbitrary binary and text resources into a self-contained package.

    Importers can load this code in a hermetic way, such that code is loaded
    from the package rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The code contained in packages is copied file-by-file from the original
    source when it is created, and the file format is a specially organized
    zip file. Future users of the package can unzip the package, and edit the code
    in order to perform custom modifications to it.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external using :meth:`extern`.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.

    When source code is added to the package, the exporter can optionally scan it
    for further code dependencies (``dependencies=True``). It looks for import statements,
    resolves relative references to qualified module names, and performs an action specified by the user
    (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`).
    """

    # The bare string below describes the ``importer`` attribute that follows it.
    """A importer that will be searched in order to find the modules referenced by other modules or by
    pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
    """
    importer: Importer
- def __init__(
- self,
- f: FileLike,
- importer: Union[Importer, Sequence[Importer]] = sys_importer,
- debug: bool = False,
- ) -> None:
- """
- Create an exporter.
- Args:
- f: The location to export to. Can be a ``string``/``Path`` object containing a filename
- or a binary I/O object.
- importer: If a single Importer is passed, use that to search for modules.
- If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them.
- debug: If set to True, add path of broken modules to PackagingErrors.
- """
- torch._C._log_api_usage_once("torch.package.PackageExporter")
- self.debug = debug
- if isinstance(f, (str, os.PathLike)):
- f = os.fspath(f)
- self.buffer: Optional[IO[bytes]] = None
- else: # is a byte buffer
- self.buffer = f
- self.zip_file = torch._C.PyTorchFileWriter(f)
- self.zip_file.set_min_version(6)
- self._written_files: set[str] = set()
- self.serialized_reduces: dict[int, Any] = {}
- # A graph tracking all the modules and pickle objects added to this
- # package and the dependencies between them.
- # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
- # - Each directed edge (u, v) means u depends on v.
- # - Nodes may contain metadata that describe how to write the thing to the zipfile.
- self.dependency_graph = DiGraph()
- self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
- self.storage_context = self.script_module_serializer.storage_context()
- # These are OrderedDicts for compatibility with RemovableHandle.
- # Generic OrderedDict type annotations are not present until 3.7.
- # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
- self._extern_hooks: OrderedDict = OrderedDict()
- self._mock_hooks: OrderedDict = OrderedDict()
- self._intern_hooks: OrderedDict = OrderedDict()
- if isinstance(importer, Importer):
- self.importer = importer
- else:
- if not isinstance(importer, collections.abc.Sequence):
- raise TypeError(
- "importer arg should be an Importer or a sequence of Importers, "
- f"got {type(importer)} instead."
- )
- self.importer = OrderedImporter(*importer)
- self.patterns: dict[GlobGroup, _PatternInfo] = {}
- self._unique_id = 0
- def save_source_file(
- self, module_name: str, file_or_directory: str, dependencies=True
- ):
- """Adds the local file system ``file_or_directory`` to the source package to provide the code
- for ``module_name``.
- Args:
- module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package.
- file_or_directory (str): the path to a file or directory of code. When a directory, all python files in the directory
- are recursively copied using :meth:`save_source_file`. If a file is named ``"/__init__.py"`` the code is treated
- as a package.
- dependencies (bool, optional): If ``True``, we scan the source for dependencies.
- """
- path = Path(file_or_directory)
- if path.is_dir():
- to_save = [] # list of tuples with arguments to save_source_string
- module_path = module_name.replace(".", "/")
- for filename in path.glob("**/*.py"):
- relative_path = filename.relative_to(path).as_posix()
- archivename = module_path + "/" + relative_path
- submodule_name = None
- if filename.name == "__init__.py":
- submodule_name = archivename[: -len("/__init__.py")].replace(
- "/", "."
- )
- is_package = True
- else:
- submodule_name = archivename[: -len(".py")].replace("/", ".")
- is_package = False
- # we delay the call to save_source_string so that we record all the source files
- # being provided by this directory structure _before_ attempting to resolve the dependencies
- # on the source. This makes sure we don't try to copy over modules that will just get
- # overwritten by this directory blob
- to_save.append(
- (
- submodule_name,
- _read_file(str(filename)),
- is_package,
- dependencies,
- )
- )
- for item in to_save:
- self.save_source_string(*item)
- else:
- is_package = path.name == "__init__.py"
- self.save_source_string(
- module_name,
- _read_file(file_or_directory),
- is_package,
- dependencies,
- )
- def get_unique_id(self) -> str:
- """Get an id. This id is guaranteed to only be handed out once for this package."""
- ret = str(self._unique_id)
- self._unique_id += 1
- return ret
- def _get_dependencies(
- self, src: str, module_name: str, is_package: bool
- ) -> list[str]:
- """Return all modules that this source code depends on.
- Dependencies are found by scanning the source code for import-like statements.
- Arguments:
- src: The Python source code to analyze for dependencies.
- module_name: The name of the module that ``src`` corresponds to.
- is_package: Whether this module should be treated as a package.
- See :py:meth:`save_source_string` for more info.
- Returns:
- A list containing modules detected as direct dependencies in
- ``src``. The items in the list are guaranteed to be unique.
- """
- package_name = (
- module_name if is_package else module_name.rsplit(".", maxsplit=1)[0]
- )
- try:
- dep_pairs = find_files_source_depends_on(src, package_name)
- except Exception as e:
- self.dependency_graph.add_node(
- module_name,
- error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
- error_context=str(e),
- )
- return []
- # Use a dict to get uniquing but also deterministic order
- dependencies = {}
- for dep_module_name, dep_module_obj in dep_pairs:
- # handle the case where someone did something like `from pack import sub`
- # where `sub` is a submodule. In this case we don't have to save pack, just sub.
- # this ensures we don't pick up additional dependencies on pack.
- # However, in the case where `sub` is not a submodule but an object, then we do have
- # to save pack.
- if dep_module_obj is not None:
- possible_submodule = f"{dep_module_name}.{dep_module_obj}"
- if self._module_exists(possible_submodule):
- dependencies[possible_submodule] = True
- # we don't need to save `pack`
- continue
- if self._module_exists(dep_module_name):
- dependencies[dep_module_name] = True
- return list(dependencies.keys())
def save_source_string(
    self,
    module_name: str,
    src: str,
    is_package: bool = False,
    dependencies: bool = True,
):
    """Register ``src`` as the source code of ``module_name`` in the exported package.

    Args:
        module_name (str): e.g. ``my_package.my_subpackage``; the name the code is provided under.
        src (str): The Python source code to save for this package.
        is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules
            (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    self.dependency_graph.add_node(
        module_name,
        source=src,
        is_package=is_package,
        provided=True,
        action=_ModuleProviderAction.INTERN,
    )

    if not dependencies:
        return

    # Link the module to everything it imports and process each import too.
    for dep in self._get_dependencies(src, module_name, is_package):
        self.dependency_graph.add_edge(module_name, dep)
        self.add_dependency(dep)
- def _write_source_string(
- self,
- module_name: str,
- src: str,
- is_package: bool = False,
- ):
- """Write ``src`` as the source code for ``module_name`` in the zip archive.
- Arguments are otherwise the same as for :meth:`save_source_string`.
- """
- extension = "/__init__.py" if is_package else ".py"
- filename = module_name.replace(".", "/") + extension
- self._write(filename, src)
- def _import_module(self, module_name: str):
- try:
- return self.importer.import_module(module_name)
- except ModuleNotFoundError:
- if not is_mangled(module_name):
- raise
- msg = (
- f"Module not found: '{module_name}'. Make sure the PackageImporter that "
- "created this module is present in `self.importer`"
- )
- raise ModuleNotFoundError(msg) from None
- def _module_exists(self, module_name: str) -> bool:
- try:
- self._import_module(module_name)
- return True
- except Exception:
- return False
- def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
- filename = None
- spec = getattr(module, "__spec__", None)
- if spec is not None:
- loader = getattr(spec, "loader", None)
- if loader is not None and isinstance(loader, SourceFileLoader):
- try:
- filename = loader.get_filename(module.__name__)
- except ImportError:
- pass
- if filename is None:
- filename = getattr(module, "__file__", None)
- if isinstance(filename, str) and filename.endswith(".py"):
- return "".join(linecache.getlines(filename, module.__dict__))
- return None
- def add_dependency(self, module_name: str, dependencies=True):
- """Given a module, add it to the dependency graph according to patterns
- specified by the user.
- """
- if (
- module_name in self.dependency_graph
- and self.dependency_graph.nodes[module_name].get("provided") is True
- ):
- return
- # Special case: PackageImporter provides a special module called
- # `torch_package_importer` that allows packaged modules to reference
- # their PackageImporter. We don't want to re-export this.
- if module_name == "torch_package_importer":
- self.dependency_graph.add_node(
- module_name,
- action=_ModuleProviderAction.SKIP,
- provided=True,
- )
- return
- if module_name == "_mock":
- self.dependency_graph.add_node(
- module_name,
- action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE,
- provided=True,
- )
- return
- if self._can_implicitly_extern(module_name):
- self.dependency_graph.add_node(
- module_name, action=_ModuleProviderAction.EXTERN, provided=True
- )
- return
- for pattern, pattern_info in self.patterns.items():
- if pattern.matches(module_name):
- pattern_info.was_matched = True
- self.dependency_graph.add_node(
- module_name, action=pattern_info.action, provided=True
- )
- if pattern_info.action == _ModuleProviderAction.DENY:
- # Requiring a denied module just adds an error to the graph.
- self.dependency_graph.add_node(
- module_name, error=PackagingErrorReason.DENIED
- )
- # If we are interning this module, we need to retrieve its
- # dependencies and package those as well.
- if pattern_info.action == _ModuleProviderAction.INTERN:
- self._intern_module(module_name, dependencies)
- return
- # No patterns have matched. Explicitly add this as an error.
- self.dependency_graph.add_node(
- module_name, error=PackagingErrorReason.NO_ACTION
- )
- def save_module(self, module_name: str, dependencies=True):
- """Save the code for ``module`` into the package. Code for the module is resolved using the ``importers`` path to find the
- module object, and then using its ``__file__`` attribute to find the source code.
- Args:
- module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code
- for this package.
- dependencies (bool, optional): If ``True``, we scan the source for dependencies.
- """
- if not isinstance(module_name, str):
- raise TypeError(
- "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
- )
- self._intern_module(module_name, dependencies)
def _intern_module(
    self,
    module_name: str,
    dependencies: bool,
):
    """Record ``module_name`` in the dependency graph as an interned module,
    together with the metadata needed to write it to the zipfile later.

    When the source cannot be found, the node is still added but carries an
    error (and, for a missing source file, the filename as context).
    """
    module_obj = self._import_module(module_name)
    # Subtle: a successful import means either the name was never mangled, or
    # it was mangled and one of the importers understood the mangling. In both
    # cases the name can now be demangled, so that the mangled form is never
    # serialized into the package.
    module_name = demangle(module_name)

    is_package = hasattr(module_obj, "__path__")
    source = self._get_source_of_module(module_obj)

    if source is not None:
        self.dependency_graph.add_node(
            module_name,
            action=_ModuleProviderAction.INTERN,
            is_package=is_package,
            source=source,
            provided=True,
        )
        if dependencies:
            # Require everything this module imports as well.
            for dep in self._get_dependencies(source, module_name, is_package):
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)
        return

    # No source available: classify the failure and mark the node as broken.
    filename = getattr(module_obj, "__file__", None)
    error_context = None
    if filename is None:
        packaging_error = PackagingErrorReason.NO_DUNDER_FILE
    elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
        packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
    else:
        packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
        error_context = f"filename: {filename}"

    self.dependency_graph.add_node(
        module_name,
        action=_ModuleProviderAction.INTERN,
        is_package=is_package,
        error=packaging_error,
        error_context=error_context,
        provided=True,
    )
def save_pickle(
    self,
    package: str,
    resource: str,
    obj: Any,
    dependencies: bool = True,
    pickle_protocol: int = 3,
):
    """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
    the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
    If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
    to reconstruct them and save the relevant code.

    To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
    ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
    have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
    for this to work.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        pickle_protocol (int, optional): Pickle protocol used for the dump. Only ``3`` and ``4``
            are accepted (checked with an assertion). Defaults to ``3``.
    """

    assert (pickle_protocol == 4) or (pickle_protocol == 3), (
        "torch.package only supports pickle protocols 3 and 4"
    )

    filename = self._filename(package, resource)
    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    # Maps a mocked module name to the fields of it that the pickle references.
    mocked_modules = defaultdict(list)
    # The pickle gets its own graph node, named '<package.resource>'.
    name_in_dependency_graph = f"<{package}.{resource}>"
    self.dependency_graph.add_node(
        name_in_dependency_graph,
        action=_ModuleProviderAction.INTERN,
        provided=True,
        is_pickle=True,
    )

    def _check_mocked_error(module: Optional[str], field: Optional[str]):
        """
        checks if an object (field) comes from a mocked module and then adds
        the pair to mocked_modules which contains mocked modules paired with their
        list of mocked objects present in the pickle.

        We also hold the invariant that the first user defined rule that applies
        to the module is the one we use.
        """

        assert isinstance(module, str)
        assert isinstance(field, str)
        if self._can_implicitly_extern(module):
            return
        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module):
                if pattern_info.action == _ModuleProviderAction.MOCK:
                    mocked_modules[module].append(field)
                # Stop at the first matching pattern, whatever its action.
                return

    if dependencies:
        # Modules referenced by the pickle, in first-seen order.
        all_dependencies = []
        # `module`/`field` hold the two most recently seen strings, the
        # older one in `module`.
        module = None
        field = None
        # Strings seen so far, keyed by the value of `memo_count` when seen.
        # defaultdict(None) has no default factory, so it acts as a plain dict.
        memo: defaultdict[int, str] = defaultdict(None)
        memo_count = 0
        # pickletools.dis(data_value)
        for opcode, arg, _pos in pickletools.genops(data_value):
            if pickle_protocol == 4:
                if (
                    opcode.name == "SHORT_BINUNICODE"
                    or opcode.name == "BINUNICODE"
                    or opcode.name == "BINUNICODE8"
                ):
                    # A new string literal: shift the (module, field) window.
                    assert isinstance(arg, str)
                    module = field
                    field = arg
                    memo[memo_count] = arg
                elif (
                    opcode.name == "LONG_BINGET"
                    or opcode.name == "BINGET"
                    or opcode.name == "GET"
                ):
                    # A memo lookup: shift the window using the remembered string.
                    assert isinstance(arg, int)
                    module = field
                    field = memo.get(arg, None)
                elif opcode.name == "MEMOIZE":
                    memo_count += 1
                elif opcode.name == "STACK_GLOBAL":
                    # The current (module, field) pair is treated as a global reference.
                    if module is None:
                        # If not module was passed on in the entries preceding this one, continue.
                        continue
                    assert isinstance(module, str)
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            elif (
                pickle_protocol == 3 and opcode.name == "GLOBAL"
            ):  # a global reference
                # The argument is split on a single space into module and field.
                assert isinstance(arg, str)
                module, field = arg.split(" ")
                if module not in all_dependencies:
                    all_dependencies.append(module)
                _check_mocked_error(module, field)
        for module_name in all_dependencies:
            self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

            """ If an object happens to come from a mocked module, then we collect these errors and spit them
            out with the other errors found by package exporter.
            """
            if module_name in mocked_modules:
                assert isinstance(module_name, str)
                fields = mocked_modules[module_name]
                self.dependency_graph.add_node(
                    module_name,
                    action=_ModuleProviderAction.MOCK,
                    error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                    error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                    f"but is being used in resource - `{resource}` in package `{package}`. ",
                    provided=True,
                )
            else:
                self.add_dependency(module_name)

    # The pickle bytes are written regardless of `dependencies`.
    self._write(filename, data_value)
- def save_text(self, package: str, resource: str, text: str):
- """Save text data to the package.
- Args:
- package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``).
- resource (str): A unique name for the resource, used to identify it to load.
- text (str): The contents to save.
- """
- return self.save_binary(package, resource, text.encode("utf-8"))
- def save_binary(self, package, resource, binary: bytes):
- """Save raw bytes to the package.
- Args:
- package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``).
- resource (str): A unique name for the resource, used to identify it to load.
- binary (str): The data to save.
- """
- filename = self._filename(package, resource)
- self._write(filename, binary)
- def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
- """Registers an extern hook on the exporter.
- The hook will be called each time a module matches against an :meth:`extern` pattern.
- It should have the following signature::
- hook(exporter: PackageExporter, module_name: str) -> None
- Hooks will be called in order of registration.
- Returns:
- :class:`torch.utils.hooks.RemovableHandle`:
- A handle that can be used to remove the added hook by calling
- ``handle.remove()``.
- """
- handle = RemovableHandle(self._extern_hooks)
- self._extern_hooks[handle.id] = hook
- return handle
- def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
- """Registers a mock hook on the exporter.
- The hook will be called each time a module matches against a :meth:`mock` pattern.
- It should have the following signature::
- hook(exporter: PackageExporter, module_name: str) -> None
- Hooks will be called in order of registration.
- Returns:
- :class:`torch.utils.hooks.RemovableHandle`:
- A handle that can be used to remove the added hook by calling
- ``handle.remove()``.
- """
- handle = RemovableHandle(self._mock_hooks)
- self._mock_hooks[handle.id] = hook
- return handle
- def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
- """Registers an intern hook on the exporter.
- The hook will be called each time a module matches against an :meth:`intern` pattern.
- It should have the following signature::
- hook(exporter: PackageExporter, module_name: str) -> None
- Hooks will be called in order of registration.
- Returns:
- :class:`torch.utils.hooks.RemovableHandle`:
- A handle that can be used to remove the added hook by calling
- ``handle.remove()``.
- """
- handle = RemovableHandle(self._intern_hooks)
- self._intern_hooks[handle.id] = hook
- return handle
def intern(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be
    included in the package and have its dependencies processed recursively.

    Args:
        include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings
            for the names of the modules to be interned. This can also be a glob-style pattern, as described in :meth:`mock`.

        exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.

        allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call
            to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob
            pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``)
            before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown.
    """
    # Patterns are stored keyed by their glob group.
    info = _PatternInfo(_ModuleProviderAction.INTERN, allow_empty)
    self.patterns[GlobGroup(include, exclude=exclude)] = info
def mock(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Register a pattern for modules that should be replaced by a mock implementation.

    Mocked modules will return a fake object for any attribute accessed from them.
    Because we copy file-by-file, dependency resolution will sometimes find files that
    are imported by model files but whose functionality is never used (e.g. custom
    serialization code or training helpers). Use this function to mock that
    functionality out without having to modify the original code.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list
            of strings naming the modules to be mocked out. Strings can also be glob-style
            patterns that may match multiple modules; any required dependency matching the
            pattern is mocked out automatically.

            Examples :
                ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g.
                ``'torch.nn'`` and ``'torch.nn.functional'``

                ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
                ``'torch.nn.functional'``
        exclude (Union[List[str], str]): An optional pattern that removes some of the modules
            matched by ``include``, e.g. ``include='torch.**', exclude='torch.foo'`` mocks all
            torch packages except ``'torch.foo'``. Defaults to ``()``.
        allow_empty (bool): Whether this pattern may go unmatched. With ``allow_empty=False``,
            an exception is thrown if :meth:`close` is called (either explicitly or via
            ``__exit__``) and the mock has not been matched to a module used by the package
            being exported. With ``allow_empty=True`` no such exception is thrown.
    """
    info = _PatternInfo(_ModuleProviderAction.MOCK, allow_empty)
    selector = GlobGroup(include, exclude=exclude)
    self.patterns[selector] = info
def extern(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Register a pattern for modules the package may import from the loading environment.

    Matching modules are not saved into the package by dependency discovery. The
    importer will load an external module directly from the standard import system,
    so code for extern modules must also exist in the process loading the package.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list
            of strings naming the modules to be externed. Glob-style patterns are accepted,
            as described in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that removes some of the modules
            matched by ``include``.
        allow_empty (bool): Whether this pattern may go unmatched. With ``allow_empty=False``,
            an exception is thrown if :meth:`close` is called (either explicitly or via
            ``__exit__``) before any module matched the pattern. With ``allow_empty=True``
            no such exception is thrown.
    """
    info = _PatternInfo(_ModuleProviderAction.EXTERN, allow_empty)
    selector = GlobGroup(include, exclude=exclude)
    self.patterns[selector] = info
def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
    """Blocklist modules whose names match the given glob patterns.

    If a dependency on any matching module is found, a :class:`PackagingError` is raised.
    Deny patterns are always registered with ``allow_empty=True``.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list
            of strings naming the modules to be denied. Glob-style patterns are accepted,
            as described in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that removes some of the modules
            matched by ``include``.
    """
    info = _PatternInfo(_ModuleProviderAction.DENY, allow_empty=True)
    selector = GlobGroup(include, exclude=exclude)
    self.patterns[selector] = info
def _persistent_id(self, obj):
    """Return an out-of-band identifier tuple for ``obj``, or ``None``.

    NOTE(review): presumably installed as the pickler's ``persistent_id``
    callback -- confirm against the caller.

    Two kinds of objects get an identifier:

    * torch storages: the storage bytes are written once to
      ``.data/<storage_id>.storage`` and the tuple
      ``("storage", storage_type, storage_id, location, storage_numel)``
      is returned.
    * objects exposing ``__reduce_package__``: the result of calling it is
      cached in ``self.serialized_reduces`` keyed by ``id(obj)`` and returned
      as ``("reduce_package", id(obj), *reduce_result)``.

    Every other object yields ``None``.

    Raises:
        RuntimeError: If ``obj`` passes the storage check but is neither a
            ``TypedStorage`` nor an ``UntypedStorage``.
        Exception: If ``obj`` is a ``RecursiveScriptModule`` while
            ``_gate_torchscript_serialization`` is truthy.
    """
    if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
        storage: Storage
        if isinstance(obj, torch.storage.TypedStorage):
            # TODO: Once we decide to break serialization FC, we can
            # remove this case
            untyped_storage = obj._untyped_storage
            storage_type_str = obj.pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage = cast(Storage, untyped_storage)
            # Typed storages report their size via obj.size().
            storage_numel = obj.size()
        elif isinstance(obj, torch.UntypedStorage):
            untyped_storage = obj
            storage = cast(Storage, untyped_storage)
            storage_type = normalize_storage_type(type(storage))
            # Untyped storages report their size in bytes.
            storage_numel = storage.nbytes()
        else:
            raise RuntimeError(f"storage type not recognized: {type(obj)}")
        # Tag is computed before any move to CPU below.
        location = location_tag(storage)
        # serialize storage if not already written
        # has_storage must be queried before get_or_add_storage, which registers
        # the storage and would make the check always true afterwards.
        storage_present = self.storage_context.has_storage(storage)
        storage_id = self.storage_context.get_or_add_storage(storage)
        if not storage_present:
            # Non-CPU storages are copied to CPU before being written out.
            if storage.device.type != "cpu":
                storage = storage.cpu()
            num_bytes = storage.nbytes()
            self.zip_file.write_record(
                f".data/{storage_id}.storage", storage, num_bytes
            )
        return ("storage", storage_type, storage_id, location, storage_numel)
    if hasattr(obj, "__reduce_package__"):
        if _gate_torchscript_serialization and isinstance(
            obj, torch.jit.RecursiveScriptModule
        ):
            raise Exception(  # noqa: TRY002
                "Serializing ScriptModules directly into a package is a beta feature. "
                "To use, set global "
                "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
            )
        # Call __reduce_package__ at most once per object; later lookups for the
        # same id(obj) reuse the cached tuple.
        if self.serialized_reduces.get(id(obj)) is None:
            self.serialized_reduces[id(obj)] = (
                "reduce_package",
                id(obj),
                *obj.__reduce_package__(self),
            )
        return self.serialized_reduces[id(obj)]
    return None
def __enter__(self):
    """Enter the ``with`` block; the exporter itself is the bound value."""
    return self
def __exit__(self, exc_type, exc_value, traceback):
    """Leave the ``with`` block, finalizing the package only on a clean exit.

    Always returns ``None``, so an exception raised inside the block keeps
    propagating to the caller.
    """
    if exc_type is None:
        # Normal exit: write out the full package.
        self.close()
        return
    # An exception is in flight: do not attempt to finalize the package.
    # Do the bare minimum to leave the open buffer in a valid state and
    # hand control back so the exception continues to be raised.
    self._finalize_zip()
def _write(self, filename, str_or_bytes):
    """Write a single record into the archive.

    Args:
        filename: Name of the record inside the archive.
        str_or_bytes: Record payload. A ``str`` is encoded as UTF-8 before
            being written; anything else is passed through unchanged.

    Raises:
        AssertionError: If ``filename`` was already written through this
            method, or if ``is_mangled(filename)`` is true.
    """
    if filename in self._written_files:
        raise AssertionError(
            f"Tried to write file '{filename}', but it already exists in this archive. "
            "Please file a bug."
        )
    # Validate the name before recording it, so a rejected write does not
    # leave ``filename`` marked as written even though no record exists.
    if is_mangled(filename):
        raise AssertionError(
            f"Tried to save a torch.package'd module as '{filename}'. "
            "Directly saving torch.package'd modules is not allowed."
        )
    self._written_files.add(filename)
    if isinstance(str_or_bytes, str):
        str_or_bytes = str_or_bytes.encode("utf-8")
    self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes))
def _validate_dependency_graph(self):
    """Raise if the dependency graph or the registered patterns are in a bad state.

    Raises:
        PackagingError: If any graph node carries an ``"error"`` entry.
        EmptyMatchError: If a pattern registered with ``allow_empty=False``
            was never matched.
    """
    graph = self.dependency_graph
    # 1. Check the graph for any errors inserted during dependency analysis.
    if any("error" in node_attrs for node_attrs in graph.nodes.values()):
        raise PackagingError(graph, debug=self.debug)
    # 2. Check that all patterns for which allow_empty=False have been matched at least once.
    for pattern, pattern_info in self.patterns.items():
        if pattern_info.allow_empty or pattern_info.was_matched:
            continue
        raise EmptyMatchError(
            f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False"
        )
def _write_mock_file(self):
    """Write the ``_mock`` module into the package unless ``_mock.py`` is already there.

    The source is read from the ``_mock.py`` file located next to this module.
    """
    if "_mock.py" in self._written_files:
        return
    mock_file = str(Path(__file__).parent / "_mock.py")
    mock_source = _read_file(mock_file)
    self._write_source_string("_mock", mock_source, is_package=False)
def _execute_dependency_graph(self):
    """Validate the finalized dependency graph, then act on every node.

    Each node's ``"action"`` decides what happens: extern nodes are collected,
    mock and intern nodes have source written to the ZIP archive, repackaged
    mock modules only ensure the shared mock file exists, and skip nodes are
    ignored. The hooks registered for extern/mock/intern run first for their
    node. Finally the names of all extern modules are written, one per line,
    to ``.data/extern_modules``.

    Raises:
        AssertionError: If an intern node lacks the ``"provided"`` entry, or a
            node carries an action that is not handled here.
    """
    self._validate_dependency_graph()

    extern_modules = []
    for module_name, node in self.dependency_graph.nodes.items():
        action = node["action"]

        if action == _ModuleProviderAction.EXTERN:
            for extern_hook in self._extern_hooks.values():
                extern_hook(self, module_name)
            extern_modules.append(module_name)
            continue

        if action == _ModuleProviderAction.MOCK:
            for mock_hook in self._mock_hooks.values():
                mock_hook(self, module_name)
            self._write_mock_file()
            # A module is treated as a package when it exposes ``__path__``.
            imported = self._import_module(module_name)
            is_package = hasattr(imported, "__path__")
            self._write_source_string(module_name, _MOCK_IMPL, is_package)
            continue

        if action == _ModuleProviderAction.INTERN:
            for intern_hook in self._intern_hooks.values():
                intern_hook(self, module_name)
            # The node in the dependency graph contains metadata that tells us
            # how to intern the module.
            if "provided" not in node:
                raise AssertionError(
                    f"Module was marked `intern` but not provided: {module_name}"
                )
            # Nodes that came from save_pickle need no source written for them.
            if node.get("is_pickle") is not True:
                is_package = node["is_package"]
                source = node["source"]
                self._write_source_string(module_name, source, is_package)
            continue

        if action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE:
            self._write_mock_file()
            continue

        if action == _ModuleProviderAction.SKIP:
            continue

        raise AssertionError(
            f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch."
        )

    extern_file_contents = "\n".join(extern_modules) + "\n"
    self._write(".data/extern_modules", extern_file_contents)
def _write_python_version(self):
    """Record the running interpreter's version string in ``.data/python_version``."""
    version = platform.python_version()
    self._write(".data/python_version", version)
def close(self):
    """Write the package to the filesystem. Any calls after :meth:`close` are now invalid.

    Using the exporter as a context manager is preferred over calling this directly::

        with PackageExporter("file.zip") as e:
            ...
    """
    # Module sources and the extern list come first.
    self._execute_dependency_graph()
    # Then the interpreter version marker.
    self._write_python_version()
    # Then whatever the script module serializer has pending.
    self.script_module_serializer.write_files()
    # Finally release the archive writer.
    self._finalize_zip()
def _finalize_zip(self):
    """Called at the very end of packaging to leave the zipfile in a closed but valid state.

    Drops the ``zip_file`` attribute and, when a buffer is set, flushes it.
    """
    # NOTE(review): presumably deleting the writer is what finishes the archive -- confirm.
    del self.zip_file
    buffer = self.buffer
    if not buffer:
        return
    buffer.flush()
def _filename(self, package, resource):
    """Build the archive path for ``resource`` inside the dotted ``package``.

    Dots in ``package`` become ``/`` and ``resource`` is passed through
    ``_normalize_path`` before the two are joined with ``/``.
    """
    package_dir = package.replace(".", "/")
    normalized = _normalize_path(resource)
    return "{}/{}".format(package_dir, normalized)
def _can_implicitly_extern(self, module_name: str):
    """Decide from a module's top-level package whether it may be externed implicitly.

    ``torch`` always qualifies. Any other top-level package qualifies only if
    it is not listed in ``_DISALLOWED_MODULES`` and ``is_stdlib_module``
    accepts it.
    """
    root_package, _, _ = module_name.partition(".")
    if root_package == "torch":
        return True
    return root_package not in _DISALLOWED_MODULES and is_stdlib_module(root_package)
def dependency_graph_string(self) -> str:
    """Render the package's dependency graph as a digraph string.

    Returns:
        The result of the dependency graph's ``to_dot()``.
    """
    graph = self.dependency_graph
    return graph.to_dot()
def _nodes_with_action_type(
    self, action: Optional[_ModuleProviderAction]
) -> list[str]:
    """Return the sorted names of graph nodes whose action equals ``action``.

    Nodes that carry an ``"is_pickle"`` key are left out. A node without an
    ``"action"`` entry is treated as having action ``None``.
    """
    return sorted(
        name
        for name, node_dict in self.dependency_graph.nodes.items()
        if node_dict.get("action") == action and "is_pickle" not in node_dict
    )
def externed_modules(self) -> list[str]:
    """List the modules that are currently externed.

    Returns:
        Sorted names of the dependency-graph nodes whose action is extern.
    """
    select = self._nodes_with_action_type
    return select(_ModuleProviderAction.EXTERN)
def interned_modules(self) -> list[str]:
    """List the modules that are currently interned.

    Returns:
        Sorted names of the dependency-graph nodes whose action is intern.
    """
    select = self._nodes_with_action_type
    return select(_ModuleProviderAction.INTERN)
def mocked_modules(self) -> list[str]:
    """List the modules that are currently mocked.

    Returns:
        Sorted names of the dependency-graph nodes whose action is mock.
    """
    select = self._nodes_with_action_type
    return select(_ModuleProviderAction.MOCK)
def denied_modules(self) -> list[str]:
    """List the modules that are currently denied.

    Returns:
        Sorted names of the dependency-graph nodes whose action is deny.
    """
    select = self._nodes_with_action_type
    return select(_ModuleProviderAction.DENY)
def get_rdeps(self, module_name: str) -> list[str]:
    """List every module that depends on ``module_name``.

    Returns:
        The names recorded as predecessors of ``module_name`` in the
        dependency graph, or an empty list when the graph has no entry for it.
    """
    predecessors = self.dependency_graph._pred
    if module_name not in predecessors:
        return []
    return list(predecessors[module_name].keys())
def all_paths(self, src: str, dst: str) -> str:
    """Describe, in dot format, the subgraph made of all paths from ``src`` to ``dst``.

    Returns:
        A dot representation containing all paths from src to dst.
        (https://graphviz.org/doc/info/lang.html)
    """
    graph = self.dependency_graph
    return graph.all_paths(src, dst)
# Standard-library modules that must never be externed automatically: even
# though they ship with Python, they offer a lot of system-level access.
# `_can_implicitly_extern` checks top-level package names against this list.
_DISALLOWED_MODULES = ["sys", "io"]
# Source text written for every mocked module: a module-level ``__getattr__``
# that returns a ``MockedObject`` for any attribute accessed on the module.
_MOCK_IMPL = """\
from _mock import MockedObject
def __getattr__(attr: str):
    return MockedObject(__name__ + '.' + attr, _suppress_err=True)
"""
def _read_file(filename: str) -> str:
    """Return the contents of ``filename`` decoded as UTF-8.

    The file is read in binary mode, so line endings are returned exactly as stored.
    """
    raw = Path(filename).read_bytes()
    return raw.decode("utf-8")
|