lenet.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle import nn
# Empty, so ``from <this module> import *`` exports nothing from this file.
# The LeNet docstring example imports the class from ``paddle.vision.models``,
# so the public name is presumably re-exported by that package — verify there.
__all__ = []
  17. class LeNet(nn.Layer):
  18. """LeNet model from
  19. `"Gradient-based learning applied to document recognition" <https://ieeexplore.ieee.org/document/726791>`_.
  20. Args:
  21. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer
  22. will not be defined. Default: 10.
  23. Returns:
  24. :ref:`api_paddle_nn_Layer`. An instance of LeNet model.
  25. Examples:
  26. .. code-block:: python
  27. >>> import paddle
  28. >>> from paddle.vision.models import LeNet
  29. >>> model = LeNet()
  30. >>> x = paddle.rand([1, 1, 28, 28])
  31. >>> out = model(x)
  32. >>> print(out.shape)
  33. [1, 10]
  34. """
  35. def __init__(self, num_classes=10):
  36. super().__init__()
  37. self.num_classes = num_classes
  38. self.features = nn.Sequential(
  39. nn.Conv2D(1, 6, 3, stride=1, padding=1),
  40. nn.ReLU(),
  41. nn.MaxPool2D(2, 2),
  42. nn.Conv2D(6, 16, 5, stride=1, padding=0),
  43. nn.ReLU(),
  44. nn.MaxPool2D(2, 2),
  45. )
  46. if num_classes > 0:
  47. self.fc = nn.Sequential(
  48. nn.Linear(400, 120),
  49. nn.Linear(120, 84),
  50. nn.Linear(84, num_classes),
  51. )
  52. def forward(self, inputs):
  53. x = self.features(inputs)
  54. if self.num_classes > 0:
  55. x = paddle.flatten(x, 1)
  56. x = self.fc(x)
  57. return x