# sparse.py 24 KB
# mypy: allow-untyped-defs
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter

from .module import Module
  8. __all__ = ["Embedding", "EmbeddingBag"]
  9. class Embedding(Module):
  10. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  11. This module is often used to store word embeddings and retrieve them using indices.
  12. The input to the module is a list of indices, and the output is the corresponding
  13. word embeddings.
  14. Args:
  15. num_embeddings (int): size of the dictionary of embeddings
  16. embedding_dim (int): the size of each embedding vector
  17. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  18. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  19. i.e. it remains as a fixed "pad". For a newly constructed Embedding,
  20. the embedding vector at :attr:`padding_idx` will default to all zeros,
  21. but can be updated to another value to be used as the padding vector.
  22. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  23. is renormalized to have norm :attr:`max_norm`.
  24. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  25. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
  26. the words in the mini-batch. Default ``False``.
  27. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
  28. See Notes for more details regarding sparse gradients.
  29. Attributes:
  30. weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
  31. initialized from :math:`\mathcal{N}(0, 1)`
  32. Shape:
  33. - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
  34. - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
  35. .. note::
  36. Keep in mind that only a limited number of optimizers support
  37. sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
  38. :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
  39. .. note::
  40. When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
  41. :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
  42. modified in-place, performing a differentiable operation on ``Embedding.weight`` before
  43. calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
  44. :attr:`max_norm` is not ``None``. For example::
  45. n, d, m = 3, 5, 7
  46. embedding = nn.Embedding(n, d, max_norm=1.0)
  47. W = torch.randn((m, d), requires_grad=True)
  48. idx = torch.tensor([1, 2])
  49. a = (
  50. embedding.weight.clone() @ W.t()
  51. ) # weight must be cloned for this to be differentiable
  52. b = embedding(idx) @ W.t() # modifies weight in-place
  53. out = a.unsqueeze(0) + b.unsqueeze(1)
  54. loss = out.sigmoid().prod()
  55. loss.backward()
  56. Examples::
  57. >>> # an Embedding module containing 10 tensors of size 3
  58. >>> embedding = nn.Embedding(10, 3)
  59. >>> # a batch of 2 samples of 4 indices each
  60. >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
  61. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  62. >>> embedding(input)
  63. tensor([[[-0.0251, -1.6902, 0.7172],
  64. [-0.6431, 0.0748, 0.6969],
  65. [ 1.4970, 1.3448, -0.9685],
  66. [-0.3677, -2.7265, -0.1685]],
  67. [[ 1.4970, 1.3448, -0.9685],
  68. [ 0.4362, -0.4004, 0.9400],
  69. [-0.6431, 0.0748, 0.6969],
  70. [ 0.9124, -2.3616, 1.1151]]])
  71. >>> # example with padding_idx
  72. >>> embedding = nn.Embedding(10, 3, padding_idx=0)
  73. >>> input = torch.LongTensor([[0, 2, 0, 5]])
  74. >>> embedding(input)
  75. tensor([[[ 0.0000, 0.0000, 0.0000],
  76. [ 0.1535, -2.0309, 0.9315],
  77. [ 0.0000, 0.0000, 0.0000],
  78. [-0.1655, 0.9897, 0.0635]]])
  79. >>> # example of changing `pad` vector
  80. >>> padding_idx = 0
  81. >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
  82. >>> embedding.weight
  83. Parameter containing:
  84. tensor([[ 0.0000, 0.0000, 0.0000],
  85. [-0.7895, -0.7089, -0.0364],
  86. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  87. >>> with torch.no_grad():
  88. ... embedding.weight[padding_idx] = torch.ones(3)
  89. >>> embedding.weight
  90. Parameter containing:
  91. tensor([[ 1.0000, 1.0000, 1.0000],
  92. [-0.7895, -0.7089, -0.0364],
  93. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  94. """
  95. __constants__ = [
  96. "num_embeddings",
  97. "embedding_dim",
  98. "padding_idx",
  99. "max_norm",
  100. "norm_type",
  101. "scale_grad_by_freq",
  102. "sparse",
  103. ]
  104. num_embeddings: int
  105. embedding_dim: int
  106. padding_idx: Optional[int]
  107. max_norm: Optional[float]
  108. norm_type: float
  109. scale_grad_by_freq: bool
  110. weight: Tensor
  111. freeze: bool
  112. sparse: bool
  113. def __init__(
  114. self,
  115. num_embeddings: int,
  116. embedding_dim: int,
  117. padding_idx: Optional[int] = None,
  118. max_norm: Optional[float] = None,
  119. norm_type: float = 2.0,
  120. scale_grad_by_freq: bool = False,
  121. sparse: bool = False,
  122. _weight: Optional[Tensor] = None,
  123. _freeze: bool = False,
  124. device=None,
  125. dtype=None,
  126. ) -> None:
  127. factory_kwargs = {"device": device, "dtype": dtype}
  128. super().__init__()
  129. self.num_embeddings = num_embeddings
  130. self.embedding_dim = embedding_dim
  131. if padding_idx is not None:
  132. if padding_idx > 0:
  133. assert padding_idx < self.num_embeddings, (
  134. "Padding_idx must be within num_embeddings"
  135. )
  136. elif padding_idx < 0:
  137. assert padding_idx >= -self.num_embeddings, (
  138. "Padding_idx must be within num_embeddings"
  139. )
  140. padding_idx = self.num_embeddings + padding_idx
  141. self.padding_idx = padding_idx
  142. self.max_norm = max_norm
  143. self.norm_type = norm_type
  144. self.scale_grad_by_freq = scale_grad_by_freq
  145. if _weight is None:
  146. self.weight = Parameter(
  147. torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
  148. requires_grad=not _freeze,
  149. )
  150. self.reset_parameters()
  151. else:
  152. assert list(_weight.shape) == [
  153. num_embeddings,
  154. embedding_dim,
  155. ], "Shape of weight does not match num_embeddings and embedding_dim"
  156. self.weight = Parameter(_weight, requires_grad=not _freeze)
  157. self.sparse = sparse
  158. def reset_parameters(self) -> None:
  159. init.normal_(self.weight)
  160. self._fill_padding_idx_with_zero()
  161. def _fill_padding_idx_with_zero(self) -> None:
  162. if self.padding_idx is not None:
  163. with torch.no_grad():
  164. self.weight[self.padding_idx].fill_(0)
  165. def forward(self, input: Tensor) -> Tensor:
  166. return F.embedding(
  167. input,
  168. self.weight,
  169. self.padding_idx,
  170. self.max_norm,
  171. self.norm_type,
  172. self.scale_grad_by_freq,
  173. self.sparse,
  174. )
  175. def extra_repr(self) -> str:
  176. s = "{num_embeddings}, {embedding_dim}"
  177. if self.padding_idx is not None:
  178. s += ", padding_idx={padding_idx}"
  179. if self.max_norm is not None:
  180. s += ", max_norm={max_norm}"
  181. if self.norm_type != 2:
  182. s += ", norm_type={norm_type}"
  183. if self.scale_grad_by_freq is not False:
  184. s += ", scale_grad_by_freq={scale_grad_by_freq}"
  185. if self.sparse is not False:
  186. s += ", sparse=True"
  187. return s.format(**self.__dict__)
  188. @classmethod
  189. def from_pretrained(
  190. cls,
  191. embeddings,
  192. freeze=True,
  193. padding_idx=None,
  194. max_norm=None,
  195. norm_type=2.0,
  196. scale_grad_by_freq=False,
  197. sparse=False,
  198. ):
  199. r"""Create Embedding instance from given 2-dimensional FloatTensor.
  200. Args:
  201. embeddings (Tensor): FloatTensor containing weights for the Embedding.
  202. First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
  203. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  204. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
  205. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  206. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  207. i.e. it remains as a fixed "pad".
  208. max_norm (float, optional): See module initialization documentation.
  209. norm_type (float, optional): See module initialization documentation. Default ``2``.
  210. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  211. sparse (bool, optional): See module initialization documentation.
  212. Examples::
  213. >>> # FloatTensor containing pretrained weights
  214. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  215. >>> embedding = nn.Embedding.from_pretrained(weight)
  216. >>> # Get embeddings for index 1
  217. >>> input = torch.LongTensor([1])
  218. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  219. >>> embedding(input)
  220. tensor([[ 4.0000, 5.1000, 6.3000]])
  221. """
  222. assert embeddings.dim() == 2, (
  223. "Embeddings parameter is expected to be 2-dimensional"
  224. )
  225. rows, cols = embeddings.shape
  226. embedding = cls(
  227. num_embeddings=rows,
  228. embedding_dim=cols,
  229. _weight=embeddings,
  230. _freeze=freeze,
  231. padding_idx=padding_idx,
  232. max_norm=max_norm,
  233. norm_type=norm_type,
  234. scale_grad_by_freq=scale_grad_by_freq,
  235. sparse=sparse,
  236. )
  237. return embedding
  238. class EmbeddingBag(Module):
  239. r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.
  240. For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
  241. and with 2D inputs, this class
  242. * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
  243. * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
  244. * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.
  245. However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
  246. operations.
  247. EmbeddingBag also supports per-sample weights as an argument to the forward
  248. pass. This scales the output of the Embedding before performing a weighted
  249. reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
  250. only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
  251. :attr:`per_sample_weights`.
  252. Args:
  253. num_embeddings (int): size of the dictionary of embeddings
  254. embedding_dim (int): the size of each embedding vector
  255. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  256. is renormalized to have norm :attr:`max_norm`.
  257. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  258. scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
  259. the words in the mini-batch. Default ``False``.
  260. Note: this option is not supported when ``mode="max"``.
  261. mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
  262. ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
  263. into consideration. ``"mean"`` computes the average of the values
  264. in the bag, ``"max"`` computes the max value over each bag.
  265. Default: ``"mean"``
  266. sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
  267. Notes for more details regarding sparse gradients. Note: this option is not
  268. supported when ``mode="max"``.
  269. include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
  270. is equivalent to the size of `indices`. This matches the CSR format.
  271. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
  272. gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
  273. during training, i.e. it remains as a fixed "pad". For a newly constructed
  274. EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
  275. zeros, but can be updated to another value to be used as the padding vector.
  276. Note that the embedding vector at :attr:`padding_idx` is excluded from the
  277. reduction.
  278. Attributes:
  279. weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
  280. initialized from :math:`\mathcal{N}(0, 1)`.
  281. Examples::
  282. >>> # an EmbeddingBag module containing 10 tensors of size 3
  283. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
  284. >>> # a batch of 2 samples of 4 indices each
  285. >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
  286. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  287. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  288. >>> embedding_sum(input, offsets)
  289. tensor([[-0.8861, -5.4350, -0.0523],
  290. [ 1.1306, -2.5798, -1.0044]])
  291. >>> # Example with padding_idx
  292. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
  293. >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
  294. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  295. >>> embedding_sum(input, offsets)
  296. tensor([[ 0.0000, 0.0000, 0.0000],
  297. [-0.7082, 3.2145, -2.6251]])
  298. >>> # An EmbeddingBag can be loaded from an Embedding like so
  299. >>> embedding = nn.Embedding(10, 3, padding_idx=2)
  300. >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
  301. embedding.weight,
  302. padding_idx=embedding.padding_idx,
  303. mode='sum')
  304. """
  305. __constants__ = [
  306. "num_embeddings",
  307. "embedding_dim",
  308. "max_norm",
  309. "norm_type",
  310. "scale_grad_by_freq",
  311. "mode",
  312. "sparse",
  313. "include_last_offset",
  314. "padding_idx",
  315. ]
  316. num_embeddings: int
  317. embedding_dim: int
  318. max_norm: Optional[float]
  319. norm_type: float
  320. scale_grad_by_freq: bool
  321. weight: Tensor
  322. mode: str
  323. sparse: bool
  324. include_last_offset: bool
  325. padding_idx: Optional[int]
  326. def __init__(
  327. self,
  328. num_embeddings: int,
  329. embedding_dim: int,
  330. max_norm: Optional[float] = None,
  331. norm_type: float = 2.0,
  332. scale_grad_by_freq: bool = False,
  333. mode: str = "mean",
  334. sparse: bool = False,
  335. _weight: Optional[Tensor] = None,
  336. include_last_offset: bool = False,
  337. padding_idx: Optional[int] = None,
  338. device=None,
  339. dtype=None,
  340. ) -> None:
  341. factory_kwargs = {"device": device, "dtype": dtype}
  342. super().__init__()
  343. self.num_embeddings = num_embeddings
  344. self.embedding_dim = embedding_dim
  345. self.max_norm = max_norm
  346. self.norm_type = norm_type
  347. self.scale_grad_by_freq = scale_grad_by_freq
  348. if padding_idx is not None:
  349. if padding_idx > 0:
  350. assert padding_idx < self.num_embeddings, (
  351. "padding_idx must be within num_embeddings"
  352. )
  353. elif padding_idx < 0:
  354. assert padding_idx >= -self.num_embeddings, (
  355. "padding_idx must be within num_embeddings"
  356. )
  357. padding_idx = self.num_embeddings + padding_idx
  358. self.padding_idx = padding_idx
  359. if _weight is None:
  360. self.weight = Parameter(
  361. torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
  362. )
  363. self.reset_parameters()
  364. else:
  365. assert list(_weight.shape) == [
  366. num_embeddings,
  367. embedding_dim,
  368. ], "Shape of weight does not match num_embeddings and embedding_dim"
  369. self.weight = Parameter(_weight)
  370. self.mode = mode
  371. self.sparse = sparse
  372. self.include_last_offset = include_last_offset
  373. def reset_parameters(self) -> None:
  374. init.normal_(self.weight)
  375. self._fill_padding_idx_with_zero()
  376. def _fill_padding_idx_with_zero(self) -> None:
  377. if self.padding_idx is not None:
  378. with torch.no_grad():
  379. self.weight[self.padding_idx].fill_(0)
  380. def forward(
  381. self,
  382. input: Tensor,
  383. offsets: Optional[Tensor] = None,
  384. per_sample_weights: Optional[Tensor] = None,
  385. ) -> Tensor:
  386. """Forward pass of EmbeddingBag.
  387. Args:
  388. input (Tensor): Tensor containing bags of indices into the embedding matrix.
  389. offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
  390. the starting index position of each bag (sequence) in :attr:`input`.
  391. per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
  392. to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
  393. must have exactly the same shape as input and is treated as having the same
  394. :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
  395. Returns:
  396. Tensor output shape of `(B, embedding_dim)`.
  397. .. note::
  398. A few notes about ``input`` and ``offsets``:
  399. - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
  400. - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
  401. each of fixed length ``N``, and this will return ``B`` values aggregated in a way
  402. depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
  403. - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
  404. multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
  405. starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
  406. :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
  407. returned vectors filled by zeros.
  408. """
  409. return F.embedding_bag(
  410. input,
  411. self.weight,
  412. offsets,
  413. self.max_norm,
  414. self.norm_type,
  415. self.scale_grad_by_freq,
  416. self.mode,
  417. self.sparse,
  418. per_sample_weights,
  419. self.include_last_offset,
  420. self.padding_idx,
  421. )
  422. def extra_repr(self) -> str:
  423. s = "{num_embeddings}, {embedding_dim}"
  424. if self.max_norm is not None:
  425. s += ", max_norm={max_norm}"
  426. if self.norm_type != 2:
  427. s += ", norm_type={norm_type}"
  428. if self.scale_grad_by_freq is not False:
  429. s += ", scale_grad_by_freq={scale_grad_by_freq}"
  430. s += ", mode={mode}"
  431. if self.padding_idx is not None:
  432. s += ", padding_idx={padding_idx}"
  433. return s.format(**{k: repr(v) for k, v in self.__dict__.items()})
  434. @classmethod
  435. def from_pretrained(
  436. cls,
  437. embeddings: Tensor,
  438. freeze: bool = True,
  439. max_norm: Optional[float] = None,
  440. norm_type: float = 2.0,
  441. scale_grad_by_freq: bool = False,
  442. mode: str = "mean",
  443. sparse: bool = False,
  444. include_last_offset: bool = False,
  445. padding_idx: Optional[int] = None,
  446. ) -> "EmbeddingBag":
  447. r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor.
  448. Args:
  449. embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
  450. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
  451. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  452. Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
  453. max_norm (float, optional): See module initialization documentation. Default: ``None``
  454. norm_type (float, optional): See module initialization documentation. Default ``2``.
  455. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  456. mode (str, optional): See module initialization documentation. Default: ``"mean"``
  457. sparse (bool, optional): See module initialization documentation. Default: ``False``.
  458. include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
  459. padding_idx (int, optional): See module initialization documentation. Default: ``None``.
  460. Examples::
  461. >>> # FloatTensor containing pretrained weights
  462. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  463. >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
  464. >>> # Get embeddings for index 1
  465. >>> input = torch.LongTensor([[1, 0]])
  466. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  467. >>> embeddingbag(input)
  468. tensor([[ 2.5000, 3.7000, 4.6500]])
  469. """
  470. assert embeddings.dim() == 2, (
  471. "Embeddings parameter is expected to be 2-dimensional"
  472. )
  473. rows, cols = embeddings.shape
  474. embeddingbag = cls(
  475. num_embeddings=rows,
  476. embedding_dim=cols,
  477. _weight=embeddings,
  478. max_norm=max_norm,
  479. norm_type=norm_type,
  480. scale_grad_by_freq=scale_grad_by_freq,
  481. mode=mode,
  482. sparse=sparse,
  483. include_last_offset=include_last_offset,
  484. padding_idx=padding_idx,
  485. )
  486. embeddingbag.weight.requires_grad = not freeze
  487. return embeddingbag