dataset.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """This is definition of dataset class, which is high performance IO."""
  15. from google.protobuf import text_format
  16. import paddle
  17. from paddle.base.proto import data_feed_pb2
  18. from ..utils import deprecated
  19. from . import core
  20. __all__ = []
  21. class DatasetFactory:
  22. """
  23. DatasetFactory is a factory which create dataset by its name,
  24. you can create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
  25. the default is "QueueDataset".
  26. Example:
  27. .. code-block:: python
  28. >>> import paddle.base as base
  29. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  30. """
  31. def __init__(self):
  32. """Init."""
  33. pass
  34. def create_dataset(self, datafeed_class="QueueDataset"):
  35. """
  36. Create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
  37. the default is "QueueDataset".
  38. Args:
  39. datafeed_class(str): datafeed class name, QueueDataset or InMemoryDataset.
  40. Default is QueueDataset.
  41. Examples:
  42. .. code-block:: python
  43. >>> import paddle.base as base
  44. >>> dataset = base.DatasetFactory().create_dataset()
  45. """
  46. try:
  47. dataset = globals()[datafeed_class]()
  48. return dataset
  49. except:
  50. raise ValueError(
  51. "datafeed class %s does not exist" % datafeed_class
  52. )
  53. class DatasetBase:
  54. """Base dataset class."""
  55. def __init__(self):
  56. """Init."""
  57. # define class name here
  58. # to decide whether we need create in memory instance
  59. self.proto_desc = data_feed_pb2.DataFeedDesc()
  60. self.proto_desc.pipe_command = "cat"
  61. self.dataset = core.Dataset("MultiSlotDataset")
  62. self.thread_num = 1
  63. self.filelist = []
  64. self.use_ps_gpu = False
  65. self.psgpu = None
  66. def set_pipe_command(self, pipe_command):
  67. """
  68. Set pipe command of current dataset
  69. A pipe command is a UNIX pipeline command that can be used only
  70. Examples:
  71. .. code-block:: python
  72. >>> import paddle.base as base
  73. >>> dataset = base.DatasetFactory().create_dataset()
  74. >>> dataset.set_pipe_command("python my_script.py")
  75. Args:
  76. pipe_command(str): pipe command
  77. """
  78. self.proto_desc.pipe_command = pipe_command
  79. def set_so_parser_name(self, so_parser_name):
  80. """
  81. Set so parser name of current dataset
  82. Examples:
  83. .. code-block:: python
  84. >>> import paddle.base as base
  85. >>> dataset = base.DatasetFactory().create_dataset()
  86. >>> dataset.set_so_parser_name("./abc.so")
  87. Args:
  88. pipe_command(str): pipe command
  89. """
  90. self.proto_desc.so_parser_name = so_parser_name
  91. def set_rank_offset(self, rank_offset):
  92. """
  93. Set rank_offset for merge_pv. It set the message of Pv.
  94. Examples:
  95. .. code-block:: python
  96. >>> import paddle.base as base
  97. >>> dataset = base.DatasetFactory().create_dataset()
  98. >>> dataset.set_rank_offset("rank_offset")
  99. Args:
  100. rank_offset(str): rank_offset's name
  101. """
  102. self.proto_desc.rank_offset = rank_offset
  103. def set_fea_eval(self, record_candidate_size, fea_eval=True):
  104. """
  105. set fea eval mode for slots shuffle to debug the importance level of
  106. slots(features), fea_eval need to be set True for slots shuffle.
  107. Args:
  108. record_candidate_size(int): size of instances candidate to shuffle
  109. one slot
  110. fea_eval(bool): whether enable fea eval mode to enable slots shuffle.
  111. default is True.
  112. Examples:
  113. .. code-block:: python
  114. >>> import paddle.base as base
  115. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  116. >>> dataset.set_fea_eval(1000000, True)
  117. """
  118. if fea_eval:
  119. self.dataset.set_fea_eval(fea_eval, record_candidate_size)
  120. self.fea_eval = fea_eval
  121. def slots_shuffle(self, slots):
  122. """
  123. Slots Shuffle
  124. Slots Shuffle is a shuffle method in slots level, which is usually used
  125. in sparse feature with large scale of instances. To compare the metric, i.e.
  126. auc while doing slots shuffle on one or several slots with baseline to
  127. evaluate the importance level of slots(features).
  128. Args:
  129. slots(list[string]): the set of slots(string) to do slots shuffle.
  130. Examples:
  131. import paddle.base as base
  132. dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  133. dataset.set_merge_by_lineid()
  134. #suppose there is a slot 0
  135. dataset.slots_shuffle(['0'])
  136. """
  137. if self.fea_eval:
  138. slots_set = set(slots)
  139. self.dataset.slots_shuffle(slots_set)
  140. def set_batch_size(self, batch_size):
  141. """
  142. Set batch size. Will be effective during training
  143. Examples:
  144. .. code-block:: python
  145. >>> import paddle.base as base
  146. >>> dataset = base.DatasetFactory().create_dataset()
  147. >>> dataset.set_batch_size(128)
  148. Args:
  149. batch_size(int): batch size
  150. """
  151. self.proto_desc.batch_size = batch_size
  152. def set_pv_batch_size(self, pv_batch_size):
  153. """
  154. Set pv batch size. It will be effective during enable_pv_merge
  155. Examples:
  156. .. code-block:: python
  157. >>> import paddle.base as base
  158. >>> dataset = base.DatasetFactory().create_dataset()
  159. >>> dataset.set_pv_batch_size(128)
  160. Args:
  161. pv_batch_size(int): pv batch size
  162. """
  163. self.proto_desc.pv_batch_size = pv_batch_size
  164. def set_thread(self, thread_num):
  165. """
  166. Set thread num, it is the num of readers.
  167. Examples:
  168. .. code-block:: python
  169. >>> import paddle.base as base
  170. >>> dataset = base.DatasetFactory().create_dataset()
  171. >>> dataset.set_thread(12)
  172. Args:
  173. thread_num(int): thread num
  174. """
  175. self.dataset.set_thread_num(thread_num)
  176. self.thread_num = thread_num
  177. def set_filelist(self, filelist):
  178. """
  179. Set file list in current worker.
  180. Examples:
  181. .. code-block:: python
  182. >>> import paddle.base as base
  183. >>> dataset = base.DatasetFactory().create_dataset()
  184. >>> dataset.set_filelist(['a.txt', 'b.txt'])
  185. Args:
  186. filelist(list): file list
  187. """
  188. self.dataset.set_filelist(filelist)
  189. self.filelist = filelist
  190. def set_input_type(self, input_type):
  191. self.proto_desc.input_type = input_type
  192. def set_use_var(self, var_list):
  193. """
  194. Set Variables which you will use.
  195. Examples:
  196. .. code-block:: python
  197. >>> import paddle.base as base
  198. >>> paddle.enable_static()
  199. >>> dataset = base.DatasetFactory().create_dataset()
  200. >>> data = paddle.static.data(name="data", shape=[None, 10, 10], dtype="int64")
  201. >>> label = paddle.static.data(name="label", shape=[None, 1], dtype="int64", lod_level=1)
  202. >>> dataset.set_use_var([data, label])
  203. Args:
  204. var_list(list): variable list
  205. """
  206. multi_slot = self.proto_desc.multi_slot_desc
  207. for var in var_list:
  208. slot_var = multi_slot.slots.add()
  209. slot_var.is_used = True
  210. slot_var.name = var.name
  211. if var.lod_level == 0:
  212. slot_var.is_dense = True
  213. slot_var.shape.extend(var.shape)
  214. if var.dtype == paddle.float32:
  215. slot_var.type = "float"
  216. elif var.dtype == paddle.int64:
  217. slot_var.type = "uint64"
  218. elif var.dtype == paddle.int32:
  219. slot_var.type = "uint32"
  220. else:
  221. raise ValueError(
  222. "Currently, base.dataset only supports dtype=float32, dtype=int32 and dtype=int64"
  223. )
  224. def set_hdfs_config(self, fs_name, fs_ugi):
  225. """
  226. Set hdfs config: fs name ad ugi
  227. Examples:
  228. .. code-block:: python
  229. >>> import paddle.base as base
  230. >>> dataset = base.DatasetFactory().create_dataset()
  231. >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
  232. Args:
  233. fs_name(str): fs name
  234. fs_ugi(str): fs ugi
  235. """
  236. self.dataset.set_hdfs_config(fs_name, fs_ugi)
  237. def set_download_cmd(self, download_cmd):
  238. """
  239. Set customized download cmd: download_cmd
  240. Examples:
  241. .. code-block:: python
  242. >>> import paddle.base as base
  243. >>> dataset = base.DatasetFactory().create_dataset()
  244. >>> dataset.set_download_cmd("./read_from_afs")
  245. Args:
  246. download_cmd(str): customized download command
  247. """
  248. self.dataset.set_download_cmd(download_cmd)
  249. def _prepare_to_run(self):
  250. """
  251. Set data_feed_desc before load or shuffle,
  252. user no need to call this function.
  253. """
  254. if self.thread_num > len(self.filelist):
  255. self.thread_num = len(self.filelist)
  256. self.dataset.set_thread_num(self.thread_num)
  257. self.dataset.set_data_feed_desc(self.desc())
  258. self.dataset.create_readers()
  259. def _set_use_ps_gpu(self, psgpu):
  260. """
  261. set use_ps_gpu flag
  262. Args:
  263. use_ps_gpu: bool
  264. """
  265. self.use_ps_gpu = True
  266. # if not defined heterps with paddle, users will not use psgpu
  267. if not core._is_compiled_with_heterps():
  268. self.use_ps_gpu = False
  269. elif self.use_ps_gpu:
  270. self.psgpu = psgpu
  271. def _finish_to_run(self):
  272. self.dataset.destroy_readers()
  273. def desc(self):
  274. """
  275. Returns a protobuf message for this DataFeedDesc
  276. Examples:
  277. .. code-block:: python
  278. >>> import paddle.base as base
  279. >>> dataset = base.DatasetFactory().create_dataset()
  280. >>> print(dataset.desc())
  281. Returns:
  282. A string message
  283. """
  284. return text_format.MessageToString(self.proto_desc)
  285. def _dynamic_adjust_before_train(self, thread_num):
  286. pass
  287. def _dynamic_adjust_after_train(self):
  288. pass
  289. class InMemoryDataset(DatasetBase):
  290. """
  291. InMemoryDataset, it will load data into memory
  292. and shuffle data before training.
  293. This class should be created by DatasetFactory
  294. Example:
  295. dataset = paddle.base.DatasetFactory().create_dataset("InMemoryDataset")
  296. """
  297. @deprecated(since="2.0.0", update_to="paddle.distributed.InMemoryDataset")
  298. def __init__(self):
  299. """Init."""
  300. super().__init__()
  301. self.proto_desc.name = "MultiSlotInMemoryDataFeed"
  302. self.fleet_send_batch_size = None
  303. self.is_user_set_queue_num = False
  304. self.queue_num = None
  305. self.parse_ins_id = False
  306. self.parse_content = False
  307. self.parse_logkey = False
  308. self.merge_by_sid = True
  309. self.enable_pv_merge = False
  310. self.merge_by_lineid = False
  311. self.fleet_send_sleep_seconds = None
  312. self.trainer_num = -1
  313. self.pass_id = 0
  314. @deprecated(
  315. since="2.0.0",
  316. update_to="paddle.distributed.InMemoryDataset._set_feed_type",
  317. )
  318. def set_feed_type(self, data_feed_type):
  319. """
  320. Set data_feed_desc
  321. """
  322. self.proto_desc.name = data_feed_type
  323. if self.proto_desc.name == "SlotRecordInMemoryDataFeed":
  324. self.dataset = core.Dataset("SlotRecordDataset")
  325. @deprecated(
  326. since="2.0.0",
  327. update_to="paddle.distributed.InMemoryDataset._prepare_to_run",
  328. )
  329. def _prepare_to_run(self):
  330. """
  331. Set data_feed_desc before load or shuffle,
  332. user no need to call this function.
  333. """
  334. if self.thread_num <= 0:
  335. self.thread_num = 1
  336. self.dataset.set_thread_num(self.thread_num)
  337. if self.queue_num is None:
  338. self.queue_num = self.thread_num
  339. self.dataset.set_queue_num(self.queue_num)
  340. self.dataset.set_parse_ins_id(self.parse_ins_id)
  341. self.dataset.set_parse_content(self.parse_content)
  342. self.dataset.set_parse_logkey(self.parse_logkey)
  343. self.dataset.set_merge_by_sid(self.merge_by_sid)
  344. self.dataset.set_enable_pv_merge(self.enable_pv_merge)
  345. self.dataset.set_data_feed_desc(self.desc())
  346. self.dataset.create_channel()
  347. self.dataset.create_readers()
  348. @deprecated(
  349. since="2.0.0",
  350. update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_before_train",
  351. )
  352. def _dynamic_adjust_before_train(self, thread_num):
  353. if not self.is_user_set_queue_num:
  354. if self.use_ps_gpu:
  355. self.dataset.dynamic_adjust_channel_num(thread_num, True)
  356. else:
  357. self.dataset.dynamic_adjust_channel_num(thread_num, False)
  358. self.dataset.dynamic_adjust_readers_num(thread_num)
  359. @deprecated(
  360. since="2.0.0",
  361. update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_after_train",
  362. )
  363. def _dynamic_adjust_after_train(self):
  364. if not self.is_user_set_queue_num:
  365. if self.use_ps_gpu:
  366. self.dataset.dynamic_adjust_channel_num(self.thread_num, True)
  367. else:
  368. self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
  369. self.dataset.dynamic_adjust_readers_num(self.thread_num)
  370. @deprecated(
  371. since="2.0.0",
  372. update_to="paddle.distributed.InMemoryDataset._set_queue_num",
  373. )
  374. def set_queue_num(self, queue_num):
  375. """
  376. Set Dataset output queue num, training threads get data from queues
  377. Args:
  378. queue_num(int): dataset output queue num
  379. Examples:
  380. .. code-block:: python
  381. >>> import paddle.base as base
  382. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  383. >>> dataset.set_queue_num(12)
  384. """
  385. self.is_user_set_queue_num = True
  386. self.queue_num = queue_num
  387. @deprecated(
  388. since="2.0.0",
  389. update_to="paddle.distributed.InMemoryDataset._set_parse_ins_id",
  390. )
  391. def set_parse_ins_id(self, parse_ins_id):
  392. """
  393. Set id Dataset need to parse insid
  394. Args:
  395. parse_ins_id(bool): if parse ins_id or not
  396. Examples:
  397. .. code-block:: python
  398. >>> import paddle.base as base
  399. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  400. >>> dataset.set_parse_ins_id(True)
  401. """
  402. self.parse_ins_id = parse_ins_id
  403. @deprecated(
  404. since="2.0.0",
  405. update_to="paddle.distributed.InMemoryDataset._set_parse_content",
  406. )
  407. def set_parse_content(self, parse_content):
  408. """
  409. Set if Dataset need to parse content
  410. Args:
  411. parse_content(bool): if parse content or not
  412. Examples:
  413. .. code-block:: python
  414. >>> import paddle.base as base
  415. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  416. >>> dataset.set_parse_content(True)
  417. """
  418. self.parse_content = parse_content
  419. def set_parse_logkey(self, parse_logkey):
  420. """
  421. Set if Dataset need to parse logkey
  422. Args:
  423. parse_content(bool): if parse logkey or not
  424. Examples:
  425. .. code-block:: python
  426. >>> import paddle.base as base
  427. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  428. >>> dataset.set_parse_logkey(True)
  429. """
  430. self.parse_logkey = parse_logkey
  431. def _set_trainer_num(self, trainer_num):
  432. """
  433. Set trainer num
  434. Args:
  435. trainer_num(int): trainer num
  436. Examples:
  437. .. code-block:: python
  438. >>> import paddle.base as base
  439. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  440. >>> dataset._set_trainer_num(1)
  441. """
  442. self.trainer_num = trainer_num
  443. @deprecated(
  444. since="2.0.0",
  445. update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid",
  446. )
  447. def set_merge_by_sid(self, merge_by_sid):
  448. """
  449. Set if Dataset need to merge sid. If not, one ins means one Pv.
  450. Args:
  451. merge_by_sid(bool): if merge sid or not
  452. Examples:
  453. .. code-block:: python
  454. >>> import paddle.base as base
  455. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  456. >>> dataset.set_merge_by_sid(True)
  457. """
  458. self.merge_by_sid = merge_by_sid
  459. def set_enable_pv_merge(self, enable_pv_merge):
  460. """
  461. Set if Dataset need to merge pv.
  462. Args:
  463. enable_pv_merge(bool): if enable_pv_merge or not
  464. Examples:
  465. .. code-block:: python
  466. >>> import paddle.base as base
  467. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  468. >>> dataset.set_enable_pv_merge(True)
  469. """
  470. self.enable_pv_merge = enable_pv_merge
  471. def preprocess_instance(self):
  472. """
  473. Merge pv instance and convey it from input_channel to input_pv_channel.
  474. It will be effective when enable_pv_merge_ is True.
  475. Examples:
  476. .. code-block:: python
  477. >>> # doctest: +SKIP('Depends on external files.')
  478. >>> import paddle.base as base
  479. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  480. >>> filelist = ["a.txt", "b.txt"]
  481. >>> dataset.set_filelist(filelist)
  482. >>> dataset.load_into_memory()
  483. >>> dataset.preprocess_instance()
  484. """
  485. self.dataset.preprocess_instance()
  486. def set_current_phase(self, current_phase):
  487. """
  488. Set current phase in train. It is useful for untest.
  489. current_phase : 1 for join, 0 for update.
  490. Examples:
  491. .. code-block:: python
  492. >>> # doctest: +SKIP('Depends on external files.')
  493. >>> import paddle.base as base
  494. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  495. >>> filelist = ["a.txt", "b.txt"]
  496. >>> dataset.set_filelist(filelist)
  497. >>> dataset.load_into_memory()
  498. >>> dataset.set_current_phase(1)
  499. """
  500. self.dataset.set_current_phase(current_phase)
  501. def postprocess_instance(self):
  502. """
  503. Divide pv instance and convey it to input_channel.
  504. Examples:
  505. .. code-block:: python
  506. >>> # doctest: +SKIP('Depends on external files.')
  507. >>> import paddle.base as base
  508. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  509. >>> filelist = ["a.txt", "b.txt"]
  510. >>> dataset.set_filelist(filelist)
  511. >>> dataset.load_into_memory()
  512. >>> dataset.preprocess_instance()
  513. >>> exe.train_from_dataset(dataset)
  514. >>> dataset.postprocess_instance()
  515. """
  516. self.dataset.postprocess_instance()
  517. @deprecated(
  518. since="2.0.0",
  519. update_to="paddle.distributed.InMemoryDataset._set_fleet_send_batch_size",
  520. )
  521. def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
  522. """
  523. Set fleet send batch size, default is 1024
  524. Args:
  525. fleet_send_batch_size(int): fleet send batch size
  526. Examples:
  527. .. code-block:: python
  528. >>> import paddle.base as base
  529. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  530. >>> dataset.set_fleet_send_batch_size(800)
  531. """
  532. self.fleet_send_batch_size = fleet_send_batch_size
  533. @deprecated(
  534. since="2.0.0",
  535. update_to="paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds",
  536. )
  537. def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
  538. """
  539. Set fleet send sleep time, default is 0
  540. Args:
  541. fleet_send_sleep_seconds(int): fleet send sleep time
  542. Examples:
  543. .. code-block:: python
  544. >>> import paddle.base as base
  545. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  546. >>> dataset.set_fleet_send_sleep_seconds(2)
  547. """
  548. self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
  549. @deprecated(
  550. since="2.0.0",
  551. update_to="paddle.distributed.InMemoryDataset._set_merge_by_lineid",
  552. )
  553. def set_merge_by_lineid(self, merge_size=2):
  554. """
  555. Set merge by line id, instances of same line id will be merged after
  556. shuffle, you should parse line id in data generator.
  557. Args:
  558. merge_size(int): ins size to merge. default is 2.
  559. Examples:
  560. .. code-block:: python
  561. >>> import paddle.base as base
  562. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  563. >>> dataset.set_merge_by_lineid()
  564. """
  565. self.dataset.set_merge_by_lineid(merge_size)
  566. self.merge_by_lineid = True
  567. self.parse_ins_id = True
  568. @deprecated(
  569. since="2.0.0",
  570. update_to="paddle.distributed.InMemoryDataset._set_generate_unique_feasigns",
  571. )
  572. def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
  573. self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
  574. self.gen_uni_feasigns = generate_uni_feasigns
  575. self.local_shard_num = shard_num
  576. @deprecated(
  577. since="2.0.0",
  578. update_to="paddle.distributed.InMemoryDataset._generate_local_tables_unlock",
  579. )
  580. def generate_local_tables_unlock(
  581. self, table_id, fea_dim, read_thread_num, consume_thread_num, shard_num
  582. ):
  583. self.dataset.generate_local_tables_unlock(
  584. table_id, fea_dim, read_thread_num, consume_thread_num, shard_num
  585. )
  586. def set_date(self, date):
  587. """
  588. :api_attr: Static Graph
  589. Set training date for pull sparse parameters, saving and loading model. Only used in psgpu
  590. Args:
  591. date(str): training date(format : YYMMDD). eg.20211111
  592. Examples:
  593. .. code-block:: python
  594. >>> import paddle.base as base
  595. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  596. >>> dataset.set_date("20211111")
  597. """
  598. year = int(date[:4])
  599. month = int(date[4:6])
  600. day = int(date[6:])
  601. if self.use_ps_gpu and core._is_compiled_with_heterps():
  602. self.psgpu.set_date(year, month, day)
  603. @deprecated(
  604. since="2.0.0",
  605. update_to="paddle.distributed.InMemoryDataset.load_into_memory",
  606. )
  607. def load_into_memory(self, is_shuffle=False):
  608. """
  609. Load data into memory
  610. Args:
  611. is_shuffle(bool): whether to use local shuffle, default is False
  612. Examples:
  613. .. code-block:: python
  614. >>> # doctest: +SKIP('Depends on external files.')
  615. >>> import paddle.base as base
  616. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  617. >>> filelist = ["a.txt", "b.txt"]
  618. >>> dataset.set_filelist(filelist)
  619. >>> dataset.load_into_memory()
  620. """
  621. self._prepare_to_run()
  622. if not self.use_ps_gpu:
  623. self.dataset.load_into_memory()
  624. elif core._is_compiled_with_heterps():
  625. self.psgpu.set_dataset(self.dataset)
  626. self.psgpu.load_into_memory(is_shuffle)
  627. @deprecated(
  628. since="2.0.0",
  629. update_to="paddle.distributed.InMemoryDataset.preload_into_memory",
  630. )
  631. def preload_into_memory(self, thread_num=None):
  632. """
  633. Load data into memory in async mode
  634. Args:
  635. thread_num(int): preload thread num
  636. Examples:
  637. .. code-block:: python
  638. >>> # doctest: +SKIP('Depends on external files.')
  639. >>> import paddle.base as base
  640. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  641. >>> filelist = ["a.txt", "b.txt"]
  642. >>> dataset.set_filelist(filelist)
  643. >>> dataset.preload_into_memory()
  644. >>> dataset.wait_preload_done()
  645. """
  646. self._prepare_to_run()
  647. if thread_num is None:
  648. thread_num = self.thread_num
  649. self.dataset.set_preload_thread_num(thread_num)
  650. self.dataset.create_preload_readers()
  651. self.dataset.preload_into_memory()
  652. @deprecated(
  653. since="2.0.0",
  654. update_to="paddle.distributed.InMemoryDataset.wait_preload_done",
  655. )
  656. def wait_preload_done(self):
  657. """
  658. Wait preload_into_memory done
  659. Examples:
  660. .. code-block:: python
  661. >>> # doctest: +SKIP('Depends on external files.')
  662. >>> import paddle.base as base
  663. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  664. >>> filelist = ["a.txt", "b.txt"]
  665. >>> dataset.set_filelist(filelist)
  666. >>> dataset.preload_into_memory()
  667. >>> dataset.wait_preload_done()
  668. """
  669. self.dataset.wait_preload_done()
  670. self.dataset.destroy_preload_readers()
  671. @deprecated(
  672. since="2.0.0",
  673. update_to="paddle.distributed.InMemoryDataset.local_shuffle",
  674. )
  675. def local_shuffle(self):
  676. """
  677. Local shuffle
  678. Examples:
  679. .. code-block:: python
  680. >>> # doctest: +SKIP('Depends on external files.')
  681. >>> import paddle.base as base
  682. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  683. >>> filelist = ["a.txt", "b.txt"]
  684. >>> dataset.set_filelist(filelist)
  685. >>> dataset.load_into_memory()
  686. >>> dataset.local_shuffle()
  687. """
  688. self.dataset.local_shuffle()
  689. @deprecated(
  690. since="2.0.0",
  691. update_to="paddle.distributed.InMemoryDataset.global_shuffle",
  692. )
  693. def global_shuffle(self, fleet=None, thread_num=12):
  694. """
  695. Global shuffle.
  696. Global shuffle can be used only in distributed mode. i.e. multiple
  697. processes on single machine or multiple machines training together.
  698. If you run in distributed mode, you should pass fleet instead of None.
  699. Examples:
  700. .. code-block:: python
  701. >>> # doctest: +SKIP('Depends on external files.')
  702. >>> import paddle.base as base
  703. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  704. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  705. >>> filelist = ["a.txt", "b.txt"]
  706. >>> dataset.set_filelist(filelist)
  707. >>> dataset.load_into_memory()
  708. >>> dataset.global_shuffle(fleet)
  709. Args:
  710. fleet(Fleet): fleet singleton. Default None.
  711. thread_num(int): shuffle thread num. Default is 12.
  712. """
  713. if fleet is not None:
  714. if hasattr(fleet, "barrier_worker"):
  715. print("pscore fleet")
  716. fleet.barrier_worker()
  717. else:
  718. fleet._role_maker.barrier_worker()
  719. if self.trainer_num == -1:
  720. self.trainer_num = fleet.worker_num()
  721. if self.fleet_send_batch_size is None:
  722. self.fleet_send_batch_size = 1024
  723. if self.fleet_send_sleep_seconds is None:
  724. self.fleet_send_sleep_seconds = 0
  725. self.dataset.register_client2client_msg_handler()
  726. self.dataset.set_trainer_num(self.trainer_num)
  727. self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
  728. self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
  729. if fleet is not None:
  730. if hasattr(fleet, "barrier_worker"):
  731. fleet.barrier_worker()
  732. else:
  733. fleet._role_maker.barrier_worker()
  734. self.dataset.global_shuffle(thread_num)
  735. if fleet is not None:
  736. if hasattr(fleet, "barrier_worker"):
  737. fleet.barrier_worker()
  738. else:
  739. fleet._role_maker.barrier_worker()
  740. if self.merge_by_lineid:
  741. self.dataset.merge_by_lineid()
  742. if fleet is not None:
  743. if hasattr(fleet, "barrier_worker"):
  744. fleet.barrier_worker()
  745. else:
  746. fleet._role_maker.barrier_worker()
  747. @deprecated(
  748. since="2.0.0",
  749. update_to="paddle.distributed.InMemoryDataset.release_memory",
  750. )
  751. def release_memory(self):
  752. """
  753. :api_attr: Static Graph
  754. Release InMemoryDataset memory data, when data will not be used again.
  755. Examples:
  756. .. code-block:: python
  757. >>> # doctest: +SKIP('Depends on external files.')
  758. >>> import paddle.base as base
  759. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  760. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  761. >>> filelist = ["a.txt", "b.txt"]
  762. >>> dataset.set_filelist(filelist)
  763. >>> dataset.load_into_memory()
  764. >>> dataset.global_shuffle(fleet)
  765. >>> exe = base.Executor(base.CPUPlace())
  766. >>> exe.run(base.default_startup_program())
  767. >>> exe.train_from_dataset(base.default_main_program(), dataset)
  768. >>> dataset.release_memory()
  769. """
  770. self.dataset.release_memory()
  771. def get_pv_data_size(self):
  772. """
  773. Get memory data size of Pv, user can call this function to know the pv num
  774. of ins in all workers after load into memory.
  775. Note:
  776. This function may cause bad performance, because it has barrier
  777. Returns:
  778. The size of memory pv data.
  779. Examples:
  780. .. code-block:: python
  781. >>> # doctest: +SKIP('Depends on external files.')
  782. >>> import paddle.base as base
  783. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  784. >>> filelist = ["a.txt", "b.txt"]
  785. >>> dataset.set_filelist(filelist)
  786. >>> dataset.load_into_memory()
  787. >>> print(dataset.get_pv_data_size())
  788. """
  789. return self.dataset.get_pv_data_size()
  790. def get_epoch_finish(self):
  791. return self.dataset.get_epoch_finish()
  792. def clear_sample_state(self):
  793. self.dataset.clear_sample_state()
  794. @deprecated(
  795. since="2.0.0",
  796. update_to="paddle.distributed.InMemoryDataset.get_memory_data_size",
  797. )
  798. def get_memory_data_size(self, fleet=None):
  799. """
  800. Get memory data size, user can call this function to know the num
  801. of ins in all workers after load into memory.
  802. Note:
  803. This function may cause bad performance, because it has barrier
  804. Args:
  805. fleet(Fleet): Fleet Object.
  806. Returns:
  807. The size of memory data.
  808. Examples:
  809. .. code-block:: python
  810. >>> # doctest: +SKIP('Depends on external files.')
  811. >>> import paddle.base as base
  812. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  813. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  814. >>> filelist = ["a.txt", "b.txt"]
  815. >>> dataset.set_filelist(filelist)
  816. >>> dataset.load_into_memory()
  817. >>> print(dataset.get_memory_data_size(fleet))
  818. """
  819. import numpy as np
  820. local_data_size = self.dataset.get_memory_data_size()
  821. local_data_size = np.array([local_data_size])
  822. if fleet is not None:
  823. global_data_size = local_data_size * 0
  824. fleet._role_maker.all_reduce_worker(
  825. local_data_size, global_data_size
  826. )
  827. return global_data_size[0]
  828. return local_data_size[0]
  829. @deprecated(
  830. since="2.0.0",
  831. update_to="paddle.distributed.InMemoryDataset.get_shuffle_data_size",
  832. )
  833. def get_shuffle_data_size(self, fleet=None):
  834. """
  835. Get shuffle data size, user can call this function to know the num
  836. of ins in all workers after local/global shuffle.
  837. Note:
  838. This function may cause bad performance to local shuffle,
  839. because it has barrier. It does not affect global shuffle.
  840. Args:
  841. fleet(Fleet): Fleet Object.
  842. Returns:
  843. The size of shuffle data.
  844. Examples:
  845. .. code-block:: python
  846. >>> # doctest: +SKIP('Depends on external files.')
  847. >>> import paddle.base as base
  848. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  849. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  850. >>> filelist = ["a.txt", "b.txt"]
  851. >>> dataset.set_filelist(filelist)
  852. >>> dataset.load_into_memory()
  853. >>> dataset.global_shuffle(fleet)
  854. >>> print(dataset.get_shuffle_data_size(fleet))
  855. """
  856. import numpy as np
  857. local_data_size = self.dataset.get_shuffle_data_size()
  858. local_data_size = np.array([local_data_size])
  859. print('global shuffle local_data_size: ', local_data_size)
  860. if fleet is not None:
  861. global_data_size = local_data_size * 0
  862. if hasattr(fleet, "util"):
  863. global_data_size = fleet.util.all_reduce(local_data_size)
  864. else:
  865. fleet._role_maker.all_reduce_worker(
  866. local_data_size, global_data_size
  867. )
  868. return global_data_size[0]
  869. return local_data_size[0]
  870. def _set_heter_ps(self, enable_heter_ps=False):
  871. """
  872. Set heter ps mode
  873. user no need to call this function.
  874. """
  875. self.dataset.set_heter_ps(enable_heter_ps)
  876. def set_graph_config(self, config):
  877. """
  878. Set graph config, user can set graph config in gpu graph mode.
  879. Args:
  880. config(dict): config dict.
  881. Returns:
  882. The size of shuffle data.
  883. Examples:
  884. .. code-block:: python
  885. >>> import paddle.base as base
  886. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  887. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  888. >>> graph_config = {"walk_len": 24,
  889. ... "walk_degree": 10,
  890. ... "once_sample_startid_len": 80000,
  891. ... "sample_times_one_chunk": 5,
  892. ... "window": 3,
  893. ... "debug_mode": 0,
  894. ... "batch_size": 800,
  895. ... "meta_path": "cuid2clk-clk2cuid;cuid2conv-conv2cuid;clk2cuid-cuid2clk;clk2cuid-cuid2conv",
  896. ... "gpu_graph_training": 1}
  897. >>> dataset.set_graph_config(graph_config)
  898. """
  899. self.proto_desc.graph_config.walk_degree = config.get("walk_degree", 1)
  900. self.proto_desc.graph_config.walk_len = config.get("walk_len", 20)
  901. self.proto_desc.graph_config.window = config.get("window", 5)
  902. self.proto_desc.graph_config.once_sample_startid_len = config.get(
  903. "once_sample_startid_len", 8000
  904. )
  905. self.proto_desc.graph_config.sample_times_one_chunk = config.get(
  906. "sample_times_one_chunk", 10
  907. )
  908. self.proto_desc.graph_config.batch_size = config.get("batch_size", 1)
  909. self.proto_desc.graph_config.debug_mode = config.get("debug_mode", 0)
  910. self.proto_desc.graph_config.first_node_type = config.get(
  911. "first_node_type", ""
  912. )
  913. self.proto_desc.graph_config.meta_path = config.get("meta_path", "")
  914. self.proto_desc.graph_config.gpu_graph_training = config.get(
  915. "gpu_graph_training", True
  916. )
  917. self.proto_desc.graph_config.sage_mode = config.get("sage_mode", False)
  918. self.proto_desc.graph_config.samples = config.get("samples", "")
  919. self.proto_desc.graph_config.train_table_cap = config.get(
  920. "train_table_cap", 800000
  921. )
  922. self.proto_desc.graph_config.infer_table_cap = config.get(
  923. "infer_table_cap", 800000
  924. )
  925. self.proto_desc.graph_config.excluded_train_pair = config.get(
  926. "excluded_train_pair", ""
  927. )
  928. self.proto_desc.graph_config.infer_node_type = config.get(
  929. "infer_node_type", ""
  930. )
  931. self.proto_desc.graph_config.get_degree = config.get(
  932. "get_degree", False
  933. )
  934. self.proto_desc.graph_config.weighted_sample = config.get(
  935. "weighted_sample", False
  936. )
  937. self.proto_desc.graph_config.return_weight = config.get(
  938. "return_weight", False
  939. )
  940. self.proto_desc.graph_config.pair_label = config.get("pair_label", "")
  941. self.proto_desc.graph_config.accumulate_num = config.get(
  942. "accumulate_num", 1
  943. )
  944. self.dataset.set_gpu_graph_mode(True)
  945. def set_pass_id(self, pass_id):
  946. """
  947. Set pass id, user can set pass id in gpu graph mode.
  948. Args:
  949. pass_id: pass id.
  950. Examples:
  951. .. code-block:: python
  952. >>> import paddle.base as base
  953. >>> pass_id = 0
  954. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  955. >>> dataset.set_pass_id(pass_id)
  956. """
  957. self.pass_id = pass_id
  958. self.dataset.set_pass_id(pass_id)
  959. def get_pass_id(self):
  960. """
  961. Get pass id, user can set pass id in gpu graph mode.
  962. Returns:
  963. The pass id.
  964. Examples:
  965. .. code-block:: python
  966. >>> import paddle.base as base
  967. >>> dataset = base.DatasetFactory().create_dataset("InMemoryDataset")
  968. >>> pass_id = dataset.get_pass_id()
  969. """
  970. return self.pass_id
  971. def dump_walk_path(self, path, dump_rate=1000):
  972. """
  973. dump_walk_path
  974. """
  975. self.dataset.dump_walk_path(path, dump_rate)
  976. def dump_sample_neighbors(self, path):
  977. """
  978. dump_sample_neighbors
  979. """
  980. self.dataset.dump_sample_neighbors(path)
  981. class QueueDataset(DatasetBase):
  982. """
  983. QueueDataset, it will process data streamly.
  984. Examples:
  985. .. code-block:: python
  986. >>> import paddle.base as base
  987. >>> dataset = base.DatasetFactory().create_dataset("QueueDataset")
  988. """
  989. def __init__(self):
  990. """
  991. Initialize QueueDataset
  992. This class should be created by DatasetFactory
  993. """
  994. super().__init__()
  995. self.proto_desc.name = "MultiSlotDataFeed"
  996. @deprecated(
  997. since="2.0.0",
  998. update_to="paddle.distributed.QueueDataset._prepare_to_run",
  999. )
  1000. def _prepare_to_run(self):
  1001. """
  1002. Set data_feed_desc/thread num/filelist before run,
  1003. user no need to call this function.
  1004. """
  1005. if self.thread_num > len(self.filelist):
  1006. self.thread_num = len(self.filelist)
  1007. if self.thread_num == 0:
  1008. self.thread_num = 1
  1009. self.dataset.set_thread_num(self.thread_num)
  1010. self.dataset.set_filelist(self.filelist)
  1011. self.dataset.set_data_feed_desc(self.desc())
  1012. self.dataset.create_readers()
  1013. def local_shuffle(self):
  1014. """
  1015. Local shuffle data.
  1016. Local shuffle is not supported in QueueDataset
  1017. NotImplementedError will be raised
  1018. Examples:
  1019. .. code-block:: python
  1020. >>> # doctest: +SKIP('NotImplementedError will be raised.')
  1021. >>> import paddle.base as base
  1022. >>> dataset = base.DatasetFactory().create_dataset("QueueDataset")
  1023. >>> dataset.local_shuffle()
  1024. Raises:
  1025. NotImplementedError: QueueDataset does not support local shuffle
  1026. """
  1027. raise NotImplementedError(
  1028. "QueueDataset does not support local shuffle, "
  1029. "please use InMemoryDataset for local_shuffle"
  1030. )
  1031. def global_shuffle(self, fleet=None):
  1032. """
  1033. Global shuffle data.
  1034. Global shuffle is not supported in QueueDataset
  1035. NotImplementedError will be raised
  1036. Args:
  1037. fleet(Fleet): fleet singleton. Default None.
  1038. Examples:
  1039. .. code-block:: python
  1040. >>> import paddle.base as base
  1041. >>> from paddle.incubate.distributed.fleet.parameter_server.pslib import fleet
  1042. >>> dataset = base.DatasetFactory().create_dataset("QueueDataset")
  1043. >>> #dataset.global_shuffle(fleet)
  1044. Raises:
  1045. NotImplementedError: QueueDataset does not support global shuffle
  1046. """
  1047. raise NotImplementedError(
  1048. "QueueDataset does not support global shuffle, "
  1049. "please use InMemoryDataset for global_shuffle"
  1050. )
  1051. class FileInstantDataset(DatasetBase):
  1052. """
  1053. FileInstantDataset, it will process data streamly.
  1054. Examples:
  1055. .. code-block:: python
  1056. >>> import paddle.base as base
  1057. >>> dataset = base.DatasetFactory.create_dataset("FileInstantDataset")
  1058. """
  1059. def __init__(self):
  1060. """
  1061. Initialize FileInstantDataset
  1062. This class should be created by DatasetFactory
  1063. """
  1064. super().__init__()
  1065. self.proto_desc.name = "MultiSlotFileInstantDataFeed"
  1066. def local_shuffle(self):
  1067. """
  1068. Local shuffle
  1069. FileInstantDataset does not support local shuffle
  1070. """
  1071. raise NotImplementedError(
  1072. "FileInstantDataset does not support local shuffle, "
  1073. "please use InMemoryDataset for local_shuffle"
  1074. )
  1075. def global_shuffle(self, fleet=None):
  1076. """
  1077. Global shuffle
  1078. FileInstantDataset does not support global shuffle
  1079. """
  1080. raise NotImplementedError(
  1081. "FileInstantDataset does not support global shuffle, "
  1082. "please use InMemoryDataset for global_shuffle"
  1083. )
  1084. class BoxPSDataset(InMemoryDataset):
  1085. """
  1086. BoxPSDataset: derived from InMemoryDataset.
  1087. Examples:
  1088. .. code-block:: python
  1089. >>> import paddle.base as base
  1090. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1091. """
  1092. def __init__(self):
  1093. """
  1094. Initialize BoxPSDataset
  1095. This class should be created by DatasetFactory
  1096. """
  1097. super().__init__()
  1098. self.boxps = core.BoxPS(self.dataset)
  1099. self.proto_desc.name = "PaddleBoxDataFeed"
  1100. def set_date(self, date):
  1101. """
  1102. Workaround for date
  1103. """
  1104. year = int(date[:4])
  1105. month = int(date[4:6])
  1106. day = int(date[6:])
  1107. self.boxps.set_date(year, month, day)
  1108. def begin_pass(self):
  1109. """
  1110. Begin Pass
  1111. Notify BoxPS to load sparse parameters of next pass to GPU Memory
  1112. Examples:
  1113. .. code-block:: python
  1114. >>> import paddle.base as base
  1115. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1116. >>> dataset.begin_pass()
  1117. """
  1118. self.boxps.begin_pass()
  1119. def end_pass(self, need_save_delta):
  1120. """
  1121. End Pass
  1122. Notify BoxPS that current pass ended
  1123. Examples:
  1124. .. code-block:: python
  1125. >>> import paddle.base as base
  1126. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1127. >>> dataset.end_pass(True)
  1128. """
  1129. self.boxps.end_pass(need_save_delta)
  1130. def wait_preload_done(self):
  1131. """
  1132. Wait async preload done
  1133. Wait Until Feed Pass Done
  1134. Examples:
  1135. .. code-block:: python
  1136. >>> # doctest: +SKIP('Depends on external files.')
  1137. >>> import paddle.base as base
  1138. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1139. >>> filelist = ["a.txt", "b.txt"]
  1140. >>> dataset.set_filelist(filelist)
  1141. >>> dataset.preload_into_memory()
  1142. >>> dataset.wait_preload_done()
  1143. """
  1144. self.boxps.wait_feed_pass_done()
  1145. def load_into_memory(self):
  1146. """
  1147. Load next pass into memory and notify boxps to fetch its emb from SSD
  1148. Examples:
  1149. .. code-block:: python
  1150. >>> # doctest: +SKIP('Depends on external files.')
  1151. >>> import paddle.base as base
  1152. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1153. >>> filelist = ["a.txt", "b.txt"]
  1154. >>> dataset.set_filelist(filelist)
  1155. >>> dataset.load_into_memory()
  1156. """
  1157. self._prepare_to_run()
  1158. self.boxps.load_into_memory()
  1159. def preload_into_memory(self):
  1160. """
  1161. Begin async preload next pass while current pass may be training
  1162. Examples:
  1163. .. code-block:: python
  1164. >>> # doctest: +SKIP('Depends on external files.')
  1165. >>> import paddle.base as base
  1166. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1167. >>> filelist = ["a.txt", "b.txt"]
  1168. >>> dataset.set_filelist(filelist)
  1169. >>> dataset.preload_into_memory()
  1170. """
  1171. self._prepare_to_run()
  1172. self.boxps.preload_into_memory()
  1173. def _dynamic_adjust_before_train(self, thread_num):
  1174. if not self.is_user_set_queue_num:
  1175. self.dataset.dynamic_adjust_channel_num(thread_num, True)
  1176. self.dataset.dynamic_adjust_readers_num(thread_num)
  1177. def _dynamic_adjust_after_train(self):
  1178. pass
  1179. def slots_shuffle(self, slots):
  1180. """
  1181. Slots Shuffle
  1182. Slots Shuffle is a shuffle method in slots level, which is usually used
  1183. in sparse feature with large scale of instances. To compare the metric, i.e.
  1184. auc while doing slots shuffle on one or several slots with baseline to
  1185. evaluate the importance level of slots(features).
  1186. Args:
  1187. slots(list[string]): the set of slots(string) to do slots shuffle.
  1188. Examples:
  1189. .. code-block:: python
  1190. >>> import paddle.base as base
  1191. >>> dataset = base.DatasetFactory().create_dataset("BoxPSDataset")
  1192. >>> dataset.set_merge_by_lineid()
  1193. >>> #suppose there is a slot 0
  1194. >>> dataset.slots_shuffle(['0'])
  1195. """
  1196. slots_set = set(slots)
  1197. self.boxps.slots_shuffle(slots_set)