Registry.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. #ifndef C10_UTIL_REGISTRY_H_
  2. #define C10_UTIL_REGISTRY_H_
  3. /**
  4. * Simple registry implementation that uses static variables to
  5. * register object creators during program initialization time.
  6. */
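// NB: The registration macros below refer to `RegistryName()` and
// `Registerer##RegistryName` by unqualified name, so this Registry works poorly
// across namespaces. Make all macro invocations from the namespace in which the
// registry was declared (e.g. from inside the `at` namespace).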
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
  21. namespace c10 {
  22. template <typename KeyType>
  23. inline std::string KeyStrRepr(const KeyType& /*key*/) {
  24. return "[key type printing not supported]";
  25. }
  26. template <>
  27. inline std::string KeyStrRepr(const std::string& key) {
  28. return key;
  29. }
  30. enum RegistryPriority {
  31. REGISTRY_FALLBACK = 1,
  32. REGISTRY_DEFAULT = 2,
  33. REGISTRY_PREFERRED = 3,
  34. };
  35. /**
  36. * @brief A template class that allows one to register classes by keys.
  37. *
  38. * The keys are usually a std::string specifying the name, but can be anything
  39. * that can be used in a std::map.
  40. *
  41. * You should most likely not use the Registry class explicitly, but use the
  42. * helper macros below to declare specific registries as well as registering
  43. * objects.
  44. */
  45. template <class SrcType, class ObjectPtrType, class... Args>
  46. class Registry {
  47. public:
  48. typedef std::function<ObjectPtrType(Args...)> Creator;
  49. Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
  50. ~Registry() = default;
  51. void Register(
  52. const SrcType& key,
  53. Creator creator,
  54. const RegistryPriority priority = REGISTRY_DEFAULT) {
  55. std::lock_guard<std::mutex> lock(register_mutex_);
  56. // The if statement below is essentially the same as the following line:
  57. // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
  58. // << " registered twice.";
  59. // However, TORCH_CHECK_EQ depends on google logging, and since registration
  60. // is carried out at static initialization time, we do not want to have an
  61. // explicit dependency on glog's initialization function.
  62. if (registry_.count(key) != 0) {
  63. auto cur_priority = priority_[key];
  64. if (priority > cur_priority) {
  65. #ifdef DEBUG
  66. std::string warn_msg =
  67. "Overwriting already registered item for key " + KeyStrRepr(key);
  68. fprintf(stderr, "%s\n", warn_msg.c_str());
  69. #endif
  70. registry_[key] = creator;
  71. priority_[key] = priority;
  72. } else if (priority == cur_priority) {
  73. std::string err_msg =
  74. "Key already registered with the same priority: " + KeyStrRepr(key);
  75. fprintf(stderr, "%s\n", err_msg.c_str());
  76. if (terminate_) {
  77. std::exit(1);
  78. } else {
  79. throw std::runtime_error(err_msg);
  80. }
  81. } else if (warning_) {
  82. std::string warn_msg =
  83. "Higher priority item already registered, skipping registration of " +
  84. KeyStrRepr(key);
  85. fprintf(stderr, "%s\n", warn_msg.c_str());
  86. }
  87. } else {
  88. registry_[key] = creator;
  89. priority_[key] = priority;
  90. }
  91. }
  92. void Register(
  93. const SrcType& key,
  94. Creator creator,
  95. const std::string& help_msg,
  96. const RegistryPriority priority = REGISTRY_DEFAULT) {
  97. Register(key, creator, priority);
  98. help_message_[key] = help_msg;
  99. }
  100. inline bool Has(const SrcType& key) {
  101. return (registry_.count(key) != 0);
  102. }
  103. ObjectPtrType Create(const SrcType& key, Args... args) {
  104. auto it = registry_.find(key);
  105. if (it == registry_.end()) {
  106. // Returns nullptr if the key is not registered.
  107. return nullptr;
  108. }
  109. return it->second(args...);
  110. }
  111. /**
  112. * Returns the keys currently registered as a std::vector.
  113. */
  114. std::vector<SrcType> Keys() const {
  115. std::vector<SrcType> keys;
  116. keys.reserve(registry_.size());
  117. for (const auto& it : registry_) {
  118. keys.push_back(it.first);
  119. }
  120. return keys;
  121. }
  122. inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
  123. return help_message_;
  124. }
  125. const char* HelpMessage(const SrcType& key) const {
  126. auto it = help_message_.find(key);
  127. if (it == help_message_.end()) {
  128. return nullptr;
  129. }
  130. return it->second.c_str();
  131. }
  132. // Used for testing, if terminate is unset, Registry throws instead of
  133. // calling std::exit
  134. void SetTerminate(bool terminate) {
  135. terminate_ = terminate;
  136. }
  137. C10_DISABLE_COPY_AND_ASSIGN(Registry);
  138. Registry(Registry&&) = delete;
  139. Registry& operator=(Registry&&) = delete;
  140. private:
  141. std::unordered_map<SrcType, Creator> registry_;
  142. std::unordered_map<SrcType, RegistryPriority> priority_;
  143. bool terminate_{true};
  144. const bool warning_;
  145. std::unordered_map<SrcType, std::string> help_message_;
  146. std::mutex register_mutex_;
  147. };
  148. template <class SrcType, class ObjectPtrType, class... Args>
  149. class Registerer {
  150. public:
  151. explicit Registerer(
  152. const SrcType& key,
  153. Registry<SrcType, ObjectPtrType, Args...>* registry,
  154. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  155. const std::string& help_msg = "") {
  156. registry->Register(key, creator, help_msg);
  157. }
  158. explicit Registerer(
  159. const SrcType& key,
  160. const RegistryPriority priority,
  161. Registry<SrcType, ObjectPtrType, Args...>* registry,
  162. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  163. const std::string& help_msg = "") {
  164. registry->Register(key, creator, help_msg, priority);
  165. }
  166. template <class DerivedType>
  167. static ObjectPtrType DefaultCreator(Args... args) {
  168. return ObjectPtrType(new DerivedType(args...));
  169. }
  170. };
  171. /**
  172. * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
  173. * declaration, as well as creating a convenient typename for its corresponding
  174. * registerer.
  175. */
  176. // Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
  177. // as import and DEFINE as export, because these registry macros will be used
  178. // in downstream shared libraries as well, and one cannot use *_API - the API
  179. // macro will be defined on a per-shared-library basis. Semantically, when one
  180. // declares a typed registry it is always going to be IMPORT, and when one
  181. // defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
  182. // the instantiation unit is always going to be exported.
  183. //
  184. // The only unique condition is when in the same file one does DECLARE and
  185. // DEFINE - in Windows compilers, this generates a warning that dllimport and
  186. // dllexport are mixed, but the warning is fine and linker will be properly
  187. // exporting the symbol. Same thing happens in the gflags flag declaration and
  188. // definition caes.
  189. #define C10_DECLARE_TYPED_REGISTRY( \
  190. RegistryName, SrcType, ObjectType, PtrType, ...) \
  191. C10_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  192. RegistryName(); \
  193. typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
  194. Registerer##RegistryName
  195. #define TORCH_DECLARE_TYPED_REGISTRY( \
  196. RegistryName, SrcType, ObjectType, PtrType, ...) \
  197. TORCH_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  198. RegistryName(); \
  199. typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
  200. Registerer##RegistryName
  201. #define C10_DEFINE_TYPED_REGISTRY( \
  202. RegistryName, SrcType, ObjectType, PtrType, ...) \
  203. C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  204. RegistryName() { \
  205. static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  206. registry = new ::c10:: \
  207. Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
  208. return registry; \
  209. }
  210. #define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  211. RegistryName, SrcType, ObjectType, PtrType, ...) \
  212. C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  213. RegistryName() { \
  214. static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  215. registry = \
  216. new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
  217. false); \
  218. return registry; \
  219. }
  220. // Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
  221. // creator with comma in its templated arguments.
  222. #define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
  223. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  224. key, RegistryName(), ##__VA_ARGS__);
  225. #define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
  226. RegistryName, key, priority, ...) \
  227. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  228. key, priority, RegistryName(), ##__VA_ARGS__);
  229. #define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
  230. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  231. key, \
  232. RegistryName(), \
  233. Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
  234. ::c10::demangle_type<__VA_ARGS__>());
  235. #define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
  236. RegistryName, key, priority, ...) \
  237. static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
  238. key, \
  239. priority, \
  240. RegistryName(), \
  241. Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
  242. ::c10::demangle_type<__VA_ARGS__>());
  243. // C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
  244. // std::string as the key type, because that is the most commonly used cases.
  245. #define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  246. C10_DECLARE_TYPED_REGISTRY( \
  247. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  248. #define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  249. TORCH_DECLARE_TYPED_REGISTRY( \
  250. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  251. #define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
  252. C10_DEFINE_TYPED_REGISTRY( \
  253. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  254. #define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
  255. C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  256. RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
  257. #define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  258. C10_DECLARE_TYPED_REGISTRY( \
  259. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  260. #define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  261. TORCH_DECLARE_TYPED_REGISTRY( \
  262. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  263. #define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  264. C10_DEFINE_TYPED_REGISTRY( \
  265. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  266. #define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
  267. RegistryName, ObjectType, ...) \
  268. C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
  269. RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  270. // C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string
  271. // as the key
  272. // type, because that is the most commonly used cases.
  273. #define C10_REGISTER_CREATOR(RegistryName, key, ...) \
  274. C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
  275. #define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
  276. C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
  277. RegistryName, #key, priority, __VA_ARGS__)
  278. #define C10_REGISTER_CLASS(RegistryName, key, ...) \
  279. C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
  280. #define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
  281. C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
  282. RegistryName, #key, priority, __VA_ARGS__)
  283. } // namespace c10
  284. #endif // C10_UTIL_REGISTRY_H_