| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379 |
- # mypy: allow-untyped-defs
- r"""Pruning methods."""
- import numbers
- from abc import ABC, abstractmethod
- from collections.abc import Iterable
- import torch
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the module tensor this instance prunes; assigned in apply().
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` according to the specific pruning
        method recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor

        Raises:
            AssertionError: if this method has not been applied to a tensor
                yet (i.e. ``_tensor_name`` was never set by :meth:`apply`).
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on.
        # `_tensor_name` is only annotated on the class, so it does not exist
        # on an instance until apply() sets it; use getattr so the assertion
        # (and its message) is actually reachable.
        assert getattr(self, "_tensor_name", None) is not None, (
            f"Module {module} has to be pruned"
        )  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method now registered as a forward pre-hook on
            ``module``: either the freshly built method, or the
            :class:`PruningContainer` that wraps it together with the
            previously applied one(s).
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with hook.name == name
            # assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            # Implicit string concatenation keeps the message on one line; a
            # backslash continuation inside the literal would embed the
            # source indentation in the text.
            assert found <= 1, (
                "Avoid adding multiple pruning hooks to the "
                f"same tensor {name} of module {module}. Use a PruningContainer."
            )

            # Build the new pruning method *before* touching the module's
            # hooks, so that a constructor failure (e.g. an invalid amount)
            # leaves the existing pruning hook in place.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # combine `methods` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importances scores.
        orig = getattr(module, name)
        if importance_scores is not None:
            assert importance_scores.shape == orig.shape, (
                f"importance_scores should have the same shape as parameter {name} of {module}"
            )
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # and deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception:
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            # bare raise re-raises the active exception with its traceback
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and returns a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.

        Returns:
            pruned version of tensor ``t``.
        """
        if importance_scores is not None:
            assert importance_scores.shape == t.shape, (
                "importance_scores should have the same shape as tensor t"
            )
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!

        Raises:
            AssertionError: if this method has not been applied to a tensor
                yet (i.e. ``_tensor_name`` was never set by :meth:`apply`).
        """
        # before removing pruning from a tensor, it has to have been applied;
        # getattr keeps the assertion reachable when `_tensor_name` was never
        # set on this instance.
        assert getattr(self, "_tensor_name", None) is not None, (
            f"Module {module} has to be pruned before pruning can be removed"
        )  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod or an iterable of
    them.
    """

    def __init__(self, *args):
        """Build the container from zero or more pruning methods.

        Args:
            *args: either nothing (empty container), one or more
                :class:`BasePruningMethod` instances, or a single iterable
                (e.g. a list) of them. The container takes its
                ``_tensor_name`` from the first pruning method it receives;
                every further method must act on that same tensor name.
        """
        self._pruning_methods: tuple[BasePruningMethod, ...] = ()
        # `args` is always a tuple. A lone non-method iterable is the
        # documented "iterable of methods" form and gets unpacked; a lone
        # BasePruningMethod (including another container, which is itself
        # iterable) is treated as a single child.
        if (
            len(args) == 1
            and not isinstance(args[0], BasePruningMethod)
            and isinstance(args[0], Iterable)
        ):
            methods = tuple(args[0])
        else:
            methods = args

        for method in methods:
            # Adopt the tensor name of the first real pruning method so that
            # the name check in add_pruning_method has something to compare
            # against.
            if getattr(self, "_tensor_name", None) is None and isinstance(
                method, BasePruningMethod
            ):
                self._tensor_name = method._tensor_name
            self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is neither ``None`` nor a
                :class:`BasePruningMethod`.
            ValueError: if ``method`` acts on a tensor name different from
                the container's.
        """
        # check that we're adding a pruning method to the container
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        """Return the number of pruning methods held by the container."""
        return len(self._pruning_methods)

    def __iter__(self):
        """Iterate over the held pruning methods in insertion order."""
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        """Return the pruning method(s) stored at ``idx``."""
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).
        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and returns a new mask.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).

            Raises:
                AttributeError: if a 'structured' method has no ``dim``.
                IndexError: if the method's ``dim`` is still negative after
                    adding the number of dimensions of ``t``.
                ValueError: if ``method.PRUNING_TYPE`` is not recognized.
            """
            new_mask = mask  # start off from existing mask
            # NOTE: Tensor.to returns the same tensor when the dtype already
            # matches, in which case the in-place write below also updates
            # the `mask` that was passed in.
            new_mask = new_mask.to(dtype=t.dtype)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                dim = method.dim
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # if dim is still negative after adding n_dims to it
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim = dim that aren't already tots 0ed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create slice to identify what to prune
                slc = [slice(None)] * n_dims
                slc[dim] = keep_channel

            elif method.PRUNING_TYPE == "global":
                n_dims = len(t.shape)  # "is this a 2D tensor? 3D? ..."
                slc = [slice(None)] * n_dims

            else:
                raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

            # compute the new mask on the unpruned slice of the tensor t
            if isinstance(slc, list):
                slc = tuple(slc)
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        # only the most recently added method contributes a new partial mask;
        # earlier methods are already reflected in `default_mask`
        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
class Identity(BasePruningMethod):
    r"""No-op pruning method.

    Sets up the pruning reparametrization (``name + '_orig'`` parameter,
    ``name + '_mask'`` buffer and forward pre-hook) without zeroing out any
    unit: the mask it produces is whatever ``default_mask`` it is given.
    """

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        """Return ``default_mask`` unchanged; ``t`` is ignored."""
        return default_mask

    @classmethod
    def apply(cls, module, name):  # type: ignore[override]
        r"""Install the identity reparametrization on ``module[name]``.

        Registers the forward pre-hook and expresses the tensor in terms of
        its original value and a pruning mask, delegating all the work to
        :meth:`BasePruningMethod.apply`.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        installed = super().apply(module, name)
        return installed
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with randomly chosen entries zeroed.

        Args:
            t (torch.Tensor): tensor being pruned; only its size, dtype and
                device are used, since the selection is random.
            default_mask (torch.Tensor): mask from previous pruning
                iterations. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with
            the selected entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to select when pruning 0 units
            prob = torch.rand_like(t)
            # rand_like preserves the memory format of `t`, so `prob` may be
            # non-contiguous (e.g. channels_last); reshape flattens it in
            # logical order where view would raise. The flat indices line up
            # with the contiguous `mask` cloned above.
            topk = torch.topk(prob.reshape(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):  # type: ignore[override]
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the smallest-magnitude entries zeroed.

        Args:
            t (torch.Tensor): importance scores of the parameter to prune;
                entries are ranked by absolute value.
            default_mask (torch.Tensor): mask from previous pruning
                iterations. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with
            the selected entries set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to select when pruning 0 units
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k.
            # torch.abs keeps the memory format of `t`, which may be
            # non-contiguous (e.g. channels_last); reshape flattens it in
            # logical order where view would raise, so the flat indices line
            # up with the contiguous `mask` cloned above.
            topk = torch.topk(
                torch.abs(t).reshape(-1), k=nparams_toprune, largest=False
            )
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):  # type: ignore[override]
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
class RandomStructured(BasePruningMethod):
    r"""Randomly prune whole (currently unpruned) channels of a tensor.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Reject an out-of-range amount up front
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Build a mask that zeroes out randomly selected channels of ``t``.

        A random score is drawn for every channel along ``self.dim``; the
        channels holding the lowest scores are dropped and the result is
        multiplied into ``default_mask`` so that earlier pruning is kept.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # "Channels" only make sense for tensors with more than 1 dimension
        _validate_structured_pruning(t)
        # self.dim has to be a valid dim of t, otherwise IndexError
        _validate_pruning_dim(t, self.dim)

        # Number of channels along the pruned dim, and how many of them go:
        # amount if int, else amount * n_channels
        n_channels = t.shape[self.dim]
        n_drop = _compute_nparams_toprune(self.amount, n_channels)
        # Should raise if more channels are requested than exist
        _validate_pruning_amount(n_drop, n_channels)

        # k=0 not supported by torch.kthvalue: nothing to prune, hand the
        # incoming mask straight back
        if n_drop == 0:
            return default_mask

        # one random number in [0, 1] per channel
        scores = torch.rand(n_channels)
        # channels whose score is above the k-th lowest one survive
        cutoff = torch.kthvalue(scores, k=n_drop).values
        survivors = scores > cutoff

        # start from all zeros and switch the surviving channels on
        selector = [slice(None)] * len(t.shape)
        selector[self.dim] = survivors
        mask = torch.zeros_like(t)
        mask[tuple(selector)] = 1

        # apply the new structured mask on top of prior (potentially
        # unstructured) mask
        mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):  # type: ignore[override]
        r"""Attach random structured pruning to ``module[name]``.

        Registers the forward pre-hook that prunes on the fly and expresses
        the tensor in terms of its original value and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)
- class LnStructured(BasePruningMethod):
- r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.
- Args:
- amount (int or float): quantity of channels to prune.
- If ``float``, should be between 0.0 and 1.0 and represent the
- fraction of parameters to prune. If ``int``, it represents the
- absolute number of parameters to prune.
- n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
- entries for argument ``p`` in :func:`torch.norm`.
- dim (int, optional): index of the dim along which we define
- channels to prune. Default: -1.
- """
- PRUNING_TYPE = "structured"
- def __init__(self, amount, n, dim=-1):
- # Check range of validity of amount
- _validate_pruning_amount_init(amount)
- self.amount = amount
- self.n = n
- self.dim = dim
- def compute_mask(self, t, default_mask):
- r"""Compute and returns a mask for the input tensor ``t``.
- Starting from a base ``default_mask`` (which should be a mask of ones
- if the tensor has not been pruned yet), generate a mask to apply on
- top of the ``default_mask`` by zeroing out the channels along the
- specified dim with the lowest L\ ``n``-norm.
- Args:
- t (torch.Tensor): tensor representing the parameter to prune
- default_mask (torch.Tensor): Base mask from previous pruning
- iterations, that need to be respected after the new mask is
- applied. Same dims as ``t``.
- Returns:
- mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
- Raises:
- IndexError: if ``self.dim >= len(t.shape)``
- """
- # Check that tensor has structure (i.e. more than 1 dimension) such
- # that the concept of "channels" makes sense
- _validate_structured_pruning(t)
- # Check that self.dim is a valid dim to index t, else raise IndexError
- _validate_pruning_dim(t, self.dim)
- # Check that the amount of channels to prune is not > than the number of
- # channels in t along the dim to prune
- tensor_size = t.shape[self.dim]
- # Compute number of units to prune: amount if int,
- # else amount * tensor_size
- nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
- nparams_tokeep = tensor_size - nparams_toprune
- # This should raise an error if the number of units to prune is larger
- # than the number of units in the tensor
- _validate_pruning_amount(nparams_toprune, tensor_size)
- # Structured pruning prunes entire channels so we need to know the
- # L_n norm along each channel to then find the topk based on this
- # metric
- norm = _compute_norm(t, self.n, self.dim)
- # largest=True --> top k; largest=False --> bottom k
- # Keep the largest k channels along dim=self.dim
- topk = torch.topk(norm, k=nparams_tokeep, largest=True)
- # topk will have .indices and .values
- # Compute binary mask by initializing it to all 0s and then filling in
- # 1s wherever topk.indices indicates, along self.dim.
- # mask has the same shape as tensor t
- def make_mask(t, dim, indices):
- # init mask to 0
- mask = torch.zeros_like(t)
- # e.g.: slc = [None, None, None], if len(t.shape) = 3
- slc = [slice(None)] * len(t.shape)
- # replace a None at position=dim with indices
- # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
- slc[dim] = indices
- slc = tuple(slc)
- # use slc to slice mask and replace all its entries with 1s
- # e.g.: mask[:, :, [0, 2, 3]] = 1
- mask[slc] = 1
- return mask
- if nparams_toprune == 0: # k=0 not supported by torch.kthvalue
- mask = default_mask
- else:
- mask = make_mask(t, self.dim, topk.indices)
- mask *= default_mask.to(dtype=mask.dtype)
- return mask
@classmethod
def apply(cls, module, name, amount, n, dim, importance_scores=None):  # type: ignore[override]
    r"""Register this pruning method on ``module`` for the tensor ``name``.

    Adds the forward pre-hook that enables pruning on the fly and the
    reparametrization of the tensor in terms of the original tensor and
    the pruning mask.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to
            prune.
        importance_scores (torch.Tensor): tensor of importance scores (of
            same shape as module parameter) used to compute mask for
            pruning. The values in this tensor indicate the importance of
            the corresponding elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in
            its place.

    Returns:
        Whatever the parent class's ``apply`` returns.
    """
    # Hyper-parameters are forwarded by keyword to the parent ``apply``.
    hyperparams = {"amount": amount, "n": n, "dim": dim}
    return super().apply(
        module, name, importance_scores=importance_scores, **hyperparams
    )
class CustomFromMask(BasePruningMethod):
    r"""Pruning method that applies a user-supplied, pre-computed mask.

    Args:
        mask (torch.Tensor): mask to combine with the parameter's existing
            mask; must have the same shape as that mask.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        r"""Combine the stored mask with ``default_mask``.

        Args:
            t (torch.Tensor): tensor being pruned (not used by this method).
            default_mask (torch.Tensor): mask already in effect for the
                parameter.

        Returns:
            torch.Tensor: elementwise product of ``default_mask`` and the
            stored mask cast to ``default_mask``'s dtype.

        Raises:
            AssertionError: if the stored mask and ``default_mask`` differ
                in shape.
        """
        # Raise explicitly rather than using ``assert``: assert statements
        # are stripped under ``python -O``, which would let a mismatched
        # mask be broadcast silently.
        if default_mask.shape != self.mask.shape:
            raise AssertionError(
                f"mask shape {tuple(self.mask.shape)} does not match the "
                f"shape of the existing mask {tuple(default_mask.shape)}"
            )
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):  # type: ignore[override]
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask (torch.Tensor): pre-computed mask to apply to the
                parameter.
        """
        return super().apply(module, name, mask=mask)
def identity(module, name):
    r"""Attach the pruning reparametrization to a parameter, pruning nothing.

    The tensor stored under ``name`` in ``module`` is reparametrized as if
    it had been pruned, but no unit is actually removed. ``module`` is
    modified in place and also returned. The reparametrization:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), "bias")
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    Identity.apply(module=module, name=name)
    return module
def random_unstructured(module, name, amount):
    r"""Prune randomly chosen (currently unpruned) units of a tensor.

    The tensor stored under ``name`` in ``module`` loses the requested
    ``amount`` of (currently unpruned) units, selected at random.
    ``module`` is modified in place and also returned. The pruning:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    RandomUnstructured.apply(module, name, amount=amount)
    return module
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune the units of a tensor that have the lowest L1-norm.

    The tensor stored under ``name`` in ``module`` loses the requested
    ``amount`` of (currently unpruned) units with the lowest L1-norm.
    ``module`` is modified in place and also returned. The pruning:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of
            same shape as module parameter) used to compute mask for
            pruning. The values in this tensor indicate the importance of
            the corresponding elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in
            its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module, name, importance_scores=importance_scores, amount=amount
    )
    return module
def random_structured(module, name, amount, dim):
    r"""Prune randomly chosen channels of a tensor along a given dimension.

    The tensor stored under ``name`` in ``module`` loses the requested
    ``amount`` of (currently unpruned) channels along ``dim``, selected at
    random. ``module`` is modified in place and also returned. The pruning:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1)
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount=amount, dim=dim)
    return module
def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune the channels of a tensor with the lowest L\ ``n``-norm along a dimension.

    The tensor stored under ``name`` in ``module`` loses the requested
    ``amount`` of (currently unpruned) channels along ``dim`` with the
    lowest L\ ``n``-norm. ``module`` is modified in place and also
    returned. The pruning:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of
            same shape as module parameter) used to compute mask for
            pruning. The values in this tensor indicate the importance of
            the corresponding elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in
            its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf")
        ... )
    """
    LnStructured.apply(
        module,
        name,
        amount=amount,
        n=n,
        dim=dim,
        importance_scores=importance_scores,
    )
    return module
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""
    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string. One-shot
            iterables such as generators are accepted; the iterable is
            consumed exactly once.
        pruning_method (function): a valid pruning function from this module,
            or a custom one implemented by the user that satisfies the
            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.
        importance_scores (dict): a dictionary mapping (module, name) tuples to
            the corresponding parameter's importance scores tensor. The tensor
            should be the same shape as the parameter, and is used for computing
            mask for pruning.
            If unspecified or None, the parameter will be used in place of its
            importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``parameters`` is not an Iterable, if
            ``importance_scores`` is not a dict, or if
            ``PRUNING_TYPE != 'unstructured'``

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(
        ...     OrderedDict(
        ...         [
        ...             ("first", nn.Linear(10, 4)),
        ...             ("second", nn.Linear(4, 1)),
        ...         ]
        ...     )
        ... )
        >>> parameters_to_prune = (
        ...     (net.first, "weight"),
        ...     (net.second, "weight"),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)

    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")

    # `parameters` is traversed three times below (scores, masks, and the
    # final per-parameter assignment). Materialize it once so that a
    # generator is not exhausted after the first pass.
    parameters = list(parameters)

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param
def custom_from_mask(module, name, mask):
    r"""Prune the parameter ``name`` of ``module`` with the pre-computed ``mask``.

    ``module`` is modified in place and also returned. The pruning:

    1) adds a named buffer called ``name+'_mask'`` holding the binary mask
       applied to the parameter ``name`` by the pruning method.
    2) replaces the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])

    """
    CustomFromMask.apply(module, name, mask=mask)
    return module
def remove(module, name):
    r"""Strip the pruning reparameterization and its forward pre-hook from a module.

    The pruned parameter named ``name`` remains permanently pruned, and the
    parameter named ``name+'_orig'`` is removed from the parameter list.
    Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): the input module, with pruning on ``name``
        made permanent.

    Raises:
        ValueError: if no pruning hook for ``name`` is registered on
            ``module``.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2)
        >>> m = remove(m, name="weight")
    """
    pre_hooks = module._forward_pre_hooks
    # Find the first registered pruning hook that targets ``name``.
    found = next(
        (
            (key, candidate)
            for key, candidate in pre_hooks.items()
            if isinstance(candidate, BasePruningMethod)
            and candidate._tensor_name == name
        ),
        None,
    )
    if found is None:
        raise ValueError(
            f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
        )

    key, pruning_hook = found
    pruning_hook.remove(module)
    del pre_hooks[key]
    return module
def is_pruned(module):
    r"""Report whether a module, or any of its submodules, carries a pruning pre-hook.

    Every module yielded by ``module.named_modules()`` is inspected, and
    the module counts as pruned as soon as one of its
    ``forward_pre_hooks`` is an instance of :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        bool: ``True`` if a pruning pre-hook is found, ``False`` otherwise.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name="weight", amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    return any(
        isinstance(pre_hook, BasePruningMethod)
        for _, child in module.named_modules()
        for pre_hook in child._forward_pre_hooks.values()
    )
- def _validate_pruning_amount_init(amount):
- r"""Validate helper to check the range of amount at init.
- Args:
- amount (int or float): quantity of parameters to prune.
- If float, should be between 0.0 and 1.0 and represent the
- fraction of parameters to prune. If int, it represents the
- absolute number of parameters to prune.
- Raises:
- ValueError: if amount is a float not in [0, 1], or if it's a negative
- integer.
- TypeError: if amount is neither a float nor an integer.
- Note:
- This does not take into account the number of parameters in the
- tensor to be pruned, which is known only at prune.
- """
- if not isinstance(amount, numbers.Real):
- raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.")
- if (isinstance(amount, numbers.Integral) and amount < 0) or (
- not isinstance(amount, numbers.Integral) # so it's a float
- and (float(amount) > 1.0 or float(amount) < 0.0)
- ):
- raise ValueError(
- f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
- )
- def _validate_pruning_amount(amount, tensor_size):
- r"""Validate that the pruning amount is meaningful wrt to the size of the data.
- Validation helper to check that the amount of parameters to prune
- is meaningful wrt to the size of the data (`tensor_size`).
- Args:
- amount (int or float): quantity of parameters to prune.
- If float, should be between 0.0 and 1.0 and represent the
- fraction of parameters to prune. If int, it represents the
- absolute number of parameters to prune.
- tensor_size (int): absolute number of parameters in the tensor
- to prune.
- """
- # TODO: consider removing this check and allowing users to specify
- # a number of units to prune that is greater than the number of units
- # left to prune. In this case, the tensor will just be fully pruned.
- if isinstance(amount, numbers.Integral) and amount > tensor_size:
- raise ValueError(
- f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
- )
- def _validate_structured_pruning(t):
- r"""Validate that the tensor to be pruned is at least 2-Dimensional.
- Validation helper to check that the tensor to be pruned is multi-
- dimensional, such that the concept of "channels" is well-defined.
- Args:
- t (torch.Tensor): tensor representing the parameter to prune
- Raises:
- ValueError: if the tensor `t` is not at least 2D.
- """
- shape = t.shape
- if len(shape) <= 1:
- raise ValueError(
- "Structured pruning can only be applied to "
- "multidimensional tensors. Found tensor of shape "
- f"{shape} with {len(shape)} dims"
- )
- def _compute_nparams_toprune(amount, tensor_size):
- r"""Convert the pruning amount from a percentage to absolute value.
- Since amount can be expressed either in absolute value or as a
- percentage of the number of units/channels in a tensor, this utility
- function converts the percentage to absolute value to standardize
- the handling of pruning.
- Args:
- amount (int or float): quantity of parameters to prune.
- If float, should be between 0.0 and 1.0 and represent the
- fraction of parameters to prune. If int, it represents the
- absolute number of parameters to prune.
- tensor_size (int): absolute number of parameters in the tensor
- to prune.
- Returns:
- int: the number of units to prune in the tensor
- """
- # incorrect type already checked in _validate_pruning_amount_init
- if isinstance(amount, numbers.Integral):
- return amount
- else:
- return round(amount * tensor_size)
- def _validate_pruning_dim(t, dim):
- r"""Validate that the pruning dimension is within the bounds of the tensor dimension.
- Args:
- t (torch.Tensor): tensor representing the parameter to prune
- dim (int): index of the dim along which we define channels to prune
- """
- if dim >= t.dim():
- raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")
- def _compute_norm(t, n, dim):
- r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.
- The L_n-norm will be computed across all entries in tensor `t` along all dimension
- except for the one identified by dim.
- Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),
- then norm will have Size [4], and each entry will represent the
- `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.
- Args:
- t (torch.Tensor): tensor representing the parameter to prune
- n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
- entries for argument p in torch.norm
- dim (int): dim identifying the channels to prune
- Returns:
- norm (torch.Tensor): L_n norm computed across all dimensions except
- for `dim`. By construction, `norm.shape = t.shape[-1]`.
- """
- # dims = all axes, except for the one identified by `dim`
- dims = list(range(t.dim()))
- # convert negative indexing
- if dim < 0:
- dim = dims[dim]
- dims.remove(dim)
- norm = torch.norm(t, p=n, dim=dims)
- return norm
|