prodigy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
"""Prodigy integration for W&B.

Users can upload Prodigy-annotated datasets directly from the local
Prodigy database to W&B as W&B Tables.

Example usage:
```python
import wandb
from wandb.integration.prodigy import upload_dataset

run = wandb.init(project="prodigy")
upload_dataset("name_of_dataset")
wandb.finish()
```
"""
import base64
import collections.abc
import io
import urllib
import urllib.error
import urllib.parse
import urllib.request
from copy import deepcopy

import pandas as pd
from PIL import Image

import wandb
from wandb import util
from wandb.plot.utils import test_missing
from wandb.sdk.lib import telemetry as wb_telemetry
  24. def named_entity(docs):
  25. """Create a named entity visualization.
  26. Taken from https://github.com/wandb/wandb/blob/main/wandb/plots/named_entity.py.
  27. """
  28. spacy = util.get_module(
  29. "spacy",
  30. required="part_of_speech requires the spacy library, install with `pip install spacy`",
  31. )
  32. util.get_module(
  33. "en_core_web_md",
  34. required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
  35. )
  36. # Test for required packages and missing & non-integer values in docs data
  37. if test_missing(docs=docs):
  38. html = spacy.displacy.render(
  39. docs, style="ent", page=True, minify=True, jupyter=False
  40. )
  41. wandb_html = wandb.Html(html)
  42. return wandb_html
  43. def merge(dict1, dict2):
  44. """Return a new dictionary by merging two dictionaries recursively."""
  45. result = deepcopy(dict1)
  46. for key, value in dict2.items():
  47. if isinstance(value, collections.abc.Mapping):
  48. result[key] = merge(result.get(key, {}), value)
  49. else:
  50. result[key] = deepcopy(dict2[key])
  51. return result
  52. def get_schema(list_data_dict, struct, array_dict_types):
  53. """Get a schema of the dataset's structure and data types."""
  54. # Get the structure of the JSON objects in the database
  55. # This is similar to getting a JSON schema but with slightly different format
  56. for _i, item in enumerate(list_data_dict):
  57. # If the list contains dict objects
  58. for k, v in item.items():
  59. # Check if key already exists in template
  60. if k not in struct.keys():
  61. if isinstance(v, list):
  62. if len(v) > 0 and isinstance(v[0], list):
  63. # nested list structure
  64. struct[k] = type(v) # type list
  65. elif len(v) > 0 and not (
  66. isinstance(v[0], list) or isinstance(v[0], dict)
  67. ):
  68. # list of singular values
  69. struct[k] = type(v) # type list
  70. else:
  71. # list of dicts
  72. array_dict_types.append(
  73. k
  74. ) # keep track of keys that are type list[dict]
  75. struct[k] = {}
  76. struct[k] = get_schema(v, struct[k], array_dict_types)
  77. elif isinstance(v, dict):
  78. struct[k] = {}
  79. struct[k] = get_schema([v], struct[k], array_dict_types)
  80. else:
  81. struct[k] = type(v)
  82. else:
  83. # Get the value of struct[k] which is the current template
  84. # Find new keys and then merge the two templates together
  85. cur_struct = struct[k]
  86. if isinstance(v, list):
  87. if len(v) > 0 and isinstance(v[0], list):
  88. # nested list coordinate structure
  89. # if the value in the item is currently None, then update
  90. if v is not None:
  91. struct[k] = type(v) # type list
  92. elif len(v) > 0 and not (
  93. isinstance(v[0], list) or isinstance(v[0], dict)
  94. ):
  95. # single list with values
  96. # if the value in the item is currently None, then update
  97. if v is not None:
  98. struct[k] = type(v) # type list
  99. else:
  100. array_dict_types.append(
  101. k
  102. ) # keep track of keys that are type list[dict]
  103. struct[k] = {}
  104. struct[k] = get_schema(v, struct[k], array_dict_types)
  105. # merge cur_struct and struct[k], remove duplicates
  106. struct[k] = merge(struct[k], cur_struct)
  107. elif isinstance(v, dict):
  108. struct[k] = {}
  109. struct[k] = get_schema([v], struct[k], array_dict_types)
  110. # merge cur_struct and struct[k], remove duplicates
  111. struct[k] = merge(struct[k], cur_struct)
  112. else:
  113. # if the value in the item is currently None, then update
  114. if v is not None:
  115. struct[k] = type(v)
  116. return struct
  117. def standardize(item, structure, array_dict_types):
  118. """Standardize all rows/entries in dataset to fit the schema.
  119. Will look for missing values and fill it in so all rows have
  120. the same items and structure.
  121. """
  122. for k, v in structure.items():
  123. if k not in item:
  124. # If the structure/field does not exist
  125. if isinstance(v, dict) and (k not in array_dict_types):
  126. # If key k is of type dict, and not not a type list[dict]
  127. item[k] = {}
  128. standardize(item[k], v, array_dict_types)
  129. elif isinstance(v, dict) and (k in array_dict_types):
  130. # If key k is of type dict, and is actually of type list[dict],
  131. # just treat as a list and set to None by default
  132. item[k] = None
  133. else:
  134. # Assign a default type
  135. item[k] = v()
  136. else:
  137. # If the structure/field already exists and is a list or dict
  138. if isinstance(item[k], list):
  139. # ignore if item is a nested list structure or list of non-dicts
  140. condition = (
  141. not (len(item[k]) > 0 and isinstance(item[k][0], list))
  142. ) and (
  143. not (
  144. len(item[k]) > 0
  145. and not (
  146. isinstance(item[k][0], list) or isinstance(item[k][0], dict)
  147. )
  148. )
  149. )
  150. if condition:
  151. for sub_item in item[k]:
  152. standardize(sub_item, v, array_dict_types)
  153. elif isinstance(item[k], dict):
  154. standardize(item[k], v, array_dict_types)
  155. def create_table(data):
  156. """Create a W&B Table.
  157. - Create/decode images from URL/Base64
  158. - Uses spacy to translate NER span data to visualizations.
  159. """
  160. # create table object from columns
  161. table_df = pd.DataFrame(data)
  162. columns = list(table_df.columns)
  163. if ("spans" in table_df.columns) and ("text" in table_df.columns):
  164. columns.append("spans_visual")
  165. if "image" in columns:
  166. columns.append("image_visual")
  167. main_table = wandb.Table(columns=columns)
  168. # Convert to dictionary format to maintain order during processing
  169. matrix = table_df.to_dict(orient="records")
  170. # Import en_core_web_md if exists
  171. en_core_web_md = util.get_module(
  172. "en_core_web_md",
  173. required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
  174. )
  175. nlp = en_core_web_md.load(disable=["ner"])
  176. # Go through each individual row
  177. for _i, document in enumerate(matrix):
  178. # Text NER span visualizations
  179. if ("spans_visual" in columns) and ("text" in columns):
  180. # Add visuals for spans
  181. document["spans_visual"] = None
  182. doc = nlp(document["text"])
  183. ents = []
  184. if ("spans" in document) and (document["spans"] is not None):
  185. for span in document["spans"]:
  186. if ("start" in span) and ("end" in span) and ("label" in span):
  187. charspan = doc.char_span(
  188. span["start"], span["end"], span["label"]
  189. )
  190. ents.append(charspan)
  191. doc.ents = ents
  192. document["spans_visual"] = named_entity(docs=doc)
  193. # Convert image link to wandb Image
  194. if "image" in columns:
  195. # Turn into wandb image
  196. document["image_visual"] = None
  197. if ("image" in document) and (document["image"] is not None):
  198. isurl = urllib.parse.urlparse(document["image"]).scheme in (
  199. "http",
  200. "https",
  201. )
  202. isbase64 = ("data:" in document["image"]) and (
  203. ";base64" in document["image"]
  204. )
  205. if isurl:
  206. # is url
  207. try:
  208. im = Image.open(urllib.request.urlopen(document["image"]))
  209. document["image_visual"] = wandb.Image(im)
  210. except urllib.error.URLError:
  211. wandb.termwarn(f"Image URL {document['image']} is invalid.")
  212. document["image_visual"] = None
  213. elif isbase64:
  214. # is base64 uri
  215. imgb64 = document["image"].split("base64,")[1]
  216. try:
  217. msg = base64.b64decode(imgb64)
  218. buf = io.BytesIO(msg)
  219. im = Image.open(buf)
  220. document["image_visual"] = wandb.Image(im)
  221. except base64.binascii.Error:
  222. wandb.termwarn(f"Base64 string {document['image']} is invalid.")
  223. document["image_visual"] = None
  224. else:
  225. # is data path
  226. document["image_visual"] = wandb.Image(document["image"])
  227. # Create row and append to table
  228. values_list = list(document.values())
  229. main_table.add_data(*values_list)
  230. return main_table
  231. def upload_dataset(dataset_name):
  232. """Upload dataset from local database to Weights & Biases.
  233. Args:
  234. dataset_name: The name of the dataset in the Prodigy database.
  235. """
  236. # Check if wandb.init has been called
  237. if wandb.run is None:
  238. raise ValueError("You must call wandb.init() before upload_dataset()")
  239. with wb_telemetry.context(run=wandb.run) as tel:
  240. tel.feature.prodigy = True
  241. prodigy_db = util.get_module(
  242. "prodigy.components.db",
  243. required="`prodigy` library is required but not installed. Please see https://prodi.gy/docs/install",
  244. )
  245. # Retrieve and upload prodigy dataset
  246. database = prodigy_db.connect()
  247. data = database.get_dataset(dataset_name)
  248. array_dict_types = []
  249. schema = get_schema(data, {}, array_dict_types)
  250. for i, _d in enumerate(data):
  251. standardize(data[i], schema, array_dict_types)
  252. table = create_table(data)
  253. wandb.log({dataset_name: table})
  254. wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")