gcn.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/gcn.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. import paddle.nn as nn
  23. import paddle.nn.functional as F
  24. class BatchNorm1D(nn.BatchNorm1D):
  25. def __init__(
  26. self,
  27. num_features,
  28. eps=1e-05,
  29. momentum=0.1,
  30. affine=True,
  31. track_running_stats=True,
  32. ):
  33. momentum = 1 - momentum
  34. weight_attr = None
  35. bias_attr = None
  36. if not affine:
  37. weight_attr = paddle.ParamAttr(learning_rate=0.0)
  38. bias_attr = paddle.ParamAttr(learning_rate=0.0)
  39. super().__init__(
  40. num_features,
  41. momentum=momentum,
  42. epsilon=eps,
  43. weight_attr=weight_attr,
  44. bias_attr=bias_attr,
  45. use_global_stats=track_running_stats,
  46. )
  47. class MeanAggregator(nn.Layer):
  48. def forward(self, features, A):
  49. x = paddle.bmm(A, features)
  50. return x
  51. class GraphConv(nn.Layer):
  52. def __init__(self, in_dim, out_dim):
  53. super().__init__()
  54. self.in_dim = in_dim
  55. self.out_dim = out_dim
  56. self.weight = self.create_parameter(
  57. [in_dim * 2, out_dim], default_initializer=nn.initializer.XavierUniform()
  58. )
  59. self.bias = self.create_parameter(
  60. [out_dim],
  61. is_bias=True,
  62. default_initializer=nn.initializer.Assign([0] * out_dim),
  63. )
  64. self.aggregator = MeanAggregator()
  65. def forward(self, features, A):
  66. b, n, d = features.shape
  67. assert d == self.in_dim
  68. agg_feats = self.aggregator(features, A)
  69. cat_feats = paddle.concat([features, agg_feats], axis=2)
  70. out = paddle.einsum("bnd,df->bnf", cat_feats, self.weight)
  71. out = F.relu(out + self.bias)
  72. return out
  73. class GCN(nn.Layer):
  74. def __init__(self, feat_len):
  75. super(GCN, self).__init__()
  76. self.bn0 = BatchNorm1D(feat_len, affine=False)
  77. self.conv1 = GraphConv(feat_len, 512)
  78. self.conv2 = GraphConv(512, 256)
  79. self.conv3 = GraphConv(256, 128)
  80. self.conv4 = GraphConv(128, 64)
  81. self.classifier = nn.Sequential(
  82. nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2)
  83. )
  84. def forward(self, x, A, knn_inds):
  85. num_local_graphs, num_max_nodes, feat_len = x.shape
  86. x = x.reshape([-1, feat_len])
  87. x = self.bn0(x)
  88. x = x.reshape([num_local_graphs, num_max_nodes, feat_len])
  89. x = self.conv1(x, A)
  90. x = self.conv2(x, A)
  91. x = self.conv3(x, A)
  92. x = self.conv4(x, A)
  93. k = knn_inds.shape[-1]
  94. mid_feat_len = x.shape[-1]
  95. edge_feat = paddle.zeros([num_local_graphs, k, mid_feat_len])
  96. for graph_ind in range(num_local_graphs):
  97. edge_feat[graph_ind, :, :] = x[graph_ind][
  98. paddle.to_tensor(knn_inds[graph_ind])
  99. ]
  100. edge_feat = edge_feat.reshape([-1, mid_feat_len])
  101. pred = self.classifier(edge_feat)
  102. return pred