  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Sep 25 08:09:52 2019
  5. @author: tanya
  6. """
  7. import os
  8. import sys
  9. import pandas as pd
  10. import numpy as np
  11. import gc
  12. sys.path.append(os.getcwd())
  13. from libraries.db_migration.ParseMapping import ParseMapping
  14. from libraries.db_migration.ParseJsonSchema import ParseJsonSchema
  15. from libraries.utils.ClassLogging import ClassLogging
  16. from libraries.utils.CleaningUtils import CleaningUtils
  17. class MigrationCleaning(ClassLogging):
  18. '''
  19. Class for correcting and filtering the incorrect data.
  20. We keep the correcting and the filtering methods separated,
  21. since there might be other custom steps in between.
  22. '''
  23. def __init__(self, mapping_path: str,
  24. schema_paths: (str, list),
  25. inconsist_report_table: str = None,
  26. filter_index_columns: (str, list) = None,
  27. mapping_source: str = "internal_name",
  28. mapping_target: str = "mongo_name",
  29. mapping_parser: type = ParseMapping,
  30. schema_parser: type = ParseJsonSchema,
  31. log_name: str = "MigrationCleaning"):
  32. '''
  33. '''
  34. super().__init__(log_name=log_name)
  35. assert isinstance(inconsist_report_table, str),\
  36. "Inconsistent report table should be a tablename string"
  37. self._inconsist_report_table = inconsist_report_table
  38. assert isinstance(filter_index_columns, (str, list)),\
  39. "Filter index columns must be a str or a list"
  40. self._filter_index_columns = list(filter_index_columns)
  41. self._schema_parser = schema_parser(schema_paths)
  42. self._mapping_parser = mapping_parser(mapping_path,
  43. source=mapping_source,
  44. target=mapping_target)
  45. self._mapping_path = mapping_path
  46. self._schema_paths = schema_paths
  47. def _assert_dataframe_input(self, data: pd.DataFrame):
  48. '''
  49. '''
  50. assert(isinstance(data, pd.DataFrame)),\
  51. "Parameter 'data' must be a pandas dataframe"
  52. @property
  53. def _field_mapping(self):
  54. '''
  55. '''
  56. return self._mapping_parser.get_field_mapping()
  57. @property
  58. def _required_fields(self):
  59. '''
  60. '''
  61. source_required_fields = self._mapping_parser.get_required_fields()
  62. target_required_fields = self._schema_parser.get_required_fields()
  63. for source_field, target_field in self._field_mapping.items():
  64. if (target_field in target_required_fields) and\
  65. (source_field not in source_required_fields):
  66. source_required_fields.append(source_field)
  67. return source_required_fields
  68. @property
  69. def _default_values(self):
  70. '''
  71. '''
  72. default_values = {}
  73. target_default_values = self._schema_parser.get_default_values()
  74. source_default_values = self._mapping_parser.get_default_values()
  75. for source_field, target_field in self._field_mapping.items():
  76. if source_field not in source_default_values:
  77. continue
  78. elif target_field not in target_default_values:
  79. target_default_values[target_field] = np.nan
  80. default_values[source_field] = {
  81. target_default_values[target_field]:
  82. source_default_values[source_field]
  83. }
  84. return default_values
  85. @property
  86. def _python_types(self):
  87. '''
  88. '''
  89. target_types = self._schema_parser.get_python_types()
  90. result = {}
  91. for source_field, target_field in self._field_mapping.items():
  92. if target_field in target_types:
  93. result[source_field] = target_types[target_field]
  94. """
  95. date_type_mismatch =\
  96. (target_field in target_types) and\
  97. (source_field in source_types) and\
  98. (target_types[target_field] == str) and\
  99. (source_types[source_field] == np.dtype('<M8[ns]'))
  100. if date_type_mismatch:
  101. target_types[target_field] = np.dtype('<M8[ns]')
  102. if (source_field in source_types) and\
  103. (target_field in target_types) and\
  104. (target_types[target_field] != source_types[source_field]):
  105. self.log_and_raise(("Type {0} of field {1} "
  106. "in schema does not match "
  107. "type {2} of field {3} in "
  108. "migration mapping")
  109. .format(target_types[target_field],
  110. target_field,
  111. source_types[source_field],
  112. source_field))
  113. if target_field in target_types:
  114. source_types[source_field] = target_types[target_field]
  115. """
  116. return result
  117. @property
  118. def _value_mappings(self):
  119. '''
  120. '''
  121. return self._mapping_parser.get_value_mappings()
  122. @property
  123. def _date_formats(self):
  124. '''
  125. '''
  126. return self._mapping_parser.get_date_formats()
  127. def _get_mongo_schema_info(self, method_name: str):
  128. '''
  129. '''
  130. result = {}
  131. target_dict = getattr(self._schema_parser, method_name)()
  132. for source_field, target_field in self._field_mapping.items():
  133. if target_field in target_dict:
  134. result[source_field] = target_dict[target_field]
  135. return result
  136. @property
  137. def _allowed_values(self):
  138. '''
  139. '''
  140. return self._get_mongo_schema_info("get_allowed_values")
  141. @property
  142. def _minimum_values(self):
  143. '''
  144. '''
  145. return self._get_mongo_schema_info("get_minimum_value")
  146. @property
  147. def _maximum_values(self):
  148. '''
  149. '''
  150. return self._get_mongo_schema_info("get_maximum_value")
  151. @property
  152. def _patterns(self):
  153. '''
  154. '''
  155. return self._get_mongo_schema_info("get_patterns")
  156. def _filter_invalid_data(self, data: pd.DataFrame,
  157. invalid_mask: pd.Series,
  158. reason: (str, pd.Series)) -> pd.DataFrame:
  159. '''
  160. '''
  161. from libraries.db_handlers.SQLHandler import SQLHandler
  162. assert((self._inconsist_report_table is not None) and
  163. (self._filter_index_columns is not None)),\
  164. "Inconsistent report table or filter index is not provided"
  165. self._assert_dataframe_input(data)
  166. data = data.copy(deep=True)
  167. db = SQLHandler()
  168. if invalid_mask.sum() == 0:
  169. return data
  170. data_inconsist = data.assign(reason=reason)\
  171. .loc[invalid_mask]\
  172. .reset_index(drop=True)
  173. db.append_to_table(data=data_inconsist,
  174. tablename=self._inconsist_report_table)
  175. n_rows_filtered = len(data_inconsist)
  176. n_instances_filtered = len(data_inconsist[self._filter_index_columns].drop_duplicates())
  177. del data_inconsist
  178. gc.collect()
  179. self._log.warning(("Filtering: {0} ."
  180. "Filtered {1} rows "
  181. "and {2} instances"
  182. .format(reason, n_rows_filtered, n_instances_filtered)))
  183. nok_index_data = data.loc[invalid_mask, self._filter_index_columns]\
  184. .drop_duplicates().reset_index(drop=True)
  185. nok_index = pd.MultiIndex.from_arrays([nok_index_data[c] for c in
  186. self._filter_index_columns])
  187. all_index = pd.MultiIndex.from_arrays([data[c] for c in
  188. self._filter_index_columns])
  189. data = data.loc[~all_index.isin(nok_index)].reset_index(drop=True)
  190. return data
  191. def _replace_values(self, data: pd.DataFrame,
  192. default: bool) -> pd.DataFrame:
  193. '''
  194. '''
  195. if default:
  196. default_str = "default"
  197. else:
  198. default_str = "equal"
  199. self._assert_dataframe_input(data)
  200. data = data.copy(deep=True)
  201. if default:
  202. mapping = self._default_values
  203. else:
  204. mapping = self._value_mappings
  205. for column, d in mapping.items():
  206. try:
  207. if column not in data.columns:
  208. continue
  209. dtype = data[column].dtype
  210. for key, values in d.items():
  211. if not default:
  212. mask = (data[column].astype(str).isin(values))
  213. else:
  214. mask = (data[column].isin(values))
  215. if default:
  216. mask = mask | (data[column].isnull())
  217. data.loc[mask, column] = key
  218. data[column] = data[column].astype(dtype)
  219. except Exception as e:
  220. self.log_and_raise(("Failed to replace {0} values "
  221. "in {1}. Exit with error {2}"
  222. .format(default_str, column, e)))
  223. self._log.info("Replaced {} values".format(default_str))
  224. return data
  225. def replace_default_values(self, data: pd.DataFrame) -> pd.DataFrame:
  226. '''
  227. '''
  228. return self._replace_values(data=data, default=True)
  229. def map_equal_values(self, data: pd.DataFrame) -> pd.DataFrame:
  230. '''
  231. '''
  232. return self._replace_values(data=data, default=False)
  233. def convert_types(self, data: pd.DataFrame) -> pd.DataFrame:
  234. '''
  235. '''
  236. self._assert_dataframe_input(data)
  237. for column, python_type in self._python_types.items():
  238. try:
  239. if column not in data.columns:
  240. continue
  241. elif column in self._date_formats:
  242. data[column] = CleaningUtils.convert_dates(
  243. series=data[column],
  244. formats=self._date_formats[column])
  245. elif (python_type == int) and data[column].isnull().any():
  246. self.log_and_raise(("Column {} contains missing values "
  247. "and cannot be of integer type"
  248. .format(column)))
  249. elif python_type == str:
  250. python_type = object
  251. else:
  252. data[column] = data[column].astype(python_type)
  253. if data[column].dtype != python_type:
  254. self._log.warning(("After conversion type in {0} "
  255. "should be {1} "
  256. "but is still {2}"
  257. .format(column,
  258. python_type,
  259. data[column].dtype)))
  260. except Exception as e:
  261. self.log_and_raise(("Failed to convert types in {0}. "
  262. "Exit with error {1}"
  263. .format(column, e)))
  264. self._log.info("Converted dtypes")
  265. return data
  266. def filter_invalid_null_values(self, data: pd.DataFrame) -> pd.DataFrame:
  267. '''
  268. '''
  269. self._assert_dataframe_input(data)
  270. for column in data.columns:
  271. if (column in self._required_fields) and\
  272. (data[column].isnull().any()):
  273. invalid_mask = data[column].isnull()
  274. reason = "Null value in the required field {}"\
  275. .format(column)
  276. data = self._filter_invalid_data(data=data,
  277. invalid_mask=invalid_mask,
  278. reason=reason)
  279. return data
  280. def filter_invalid_types(self, data: pd.DataFrame) -> pd.DataFrame():
  281. '''
  282. '''
  283. self._assert_dataframe_input(data)
  284. for column, python_type in self._python_types.items():
  285. if data[column].dtype != python_type:
  286. def mismatch_type(x):
  287. return type(x) != python_type
  288. invalid_mask = data[column].apply(mismatch_type)
  289. reason = "Type mismatch if field {}".format(column)
  290. data = self._filter_invalid_data(data=data,
  291. invalid_mask=invalid_mask,
  292. reason=reason)
  293. return data
  294. def filter_invalid_patterns(self, data: pd.DataFrame) -> pd.DataFrame:
  295. '''
  296. '''
  297. self._assert_dataframe_input(data)
  298. for column, pattern in self._patterns:
  299. invalid_mask = (~data[column].astype(str).str.match(pattern))
  300. reason = "Pattern mismatch in field {}".format(column)
  301. data = self._filter_invalid_data(data=data,
  302. invalid_mask=invalid_mask,
  303. reason=reason)
  304. return data
  305. def filter_notallowed_values(self, data: pd.DataFrame) -> pd.DataFrame:
  306. '''
  307. '''
  308. for column, value in self._minimum_values.items():
  309. invalid_mask = data[column] > value
  310. reason = "Too large values in field {}".format(column)
  311. data = self._filter_invalid_data(data=data,
  312. invalid_mask=invalid_mask,
  313. reason=reason)
  314. for column, value in self._maximum_values.items():
  315. invalid_mask = data[column] < value
  316. reason = "Too small values in field {}".format(column)
  317. data = self._filter_invalid_data(data=data,
  318. invalid_mask=invalid_mask,
  319. reason=reason)
  320. for column, allowed_values in self._allowed_values.items():
  321. invalid_mask = (~data[column].isin(allowed_values))
  322. reason = "Too small values in field {}".format(column)
  323. data = self._filter_invalid_data(data=data,
  324. invalid_mask=invalid_mask,
  325. reason=reason)
  326. return data
  327. if __name__ == "__main__":
  328. # testing
  329. from libraries.db_handlers.SQLHandler import SQLHandler
  330. mapping_path = os.path.join(".", "migration_mappings", "rs1_mapping.json")
  331. schema_paths = [
  332. os.path.join(".", "mongo_schema", "schema_wheelsets.json"),
  333. os.path.join(".", "mongo_schema", "schema_process_instances.json")]
  334. inconsist_report_table = "test_inconsist_report_rs1"
  335. if all([os.path.isfile(p) for p in schema_paths + [mapping_path]]):
  336. print("Found schemas!")
  337. cleaner = MigrationCleaning(
  338. mapping_path=mapping_path,
  339. schema_paths=schema_paths,
  340. mapping_source="internal_name",
  341. mapping_target="mongo_name",
  342. filter_index_columns=["radsatznummer"],
  343. inconsist_report_table=inconsist_report_table)
  344. db = SQLHandler()
  345. data = db.read_sql_to_dataframe("select * from rs1 limit 100")
  346. data = cleaner.replace_default_values(data)
  347. data = cleaner.map_equal_values(data)
  348. data = cleaner.convert_types(data)
  349. non_filtered_len = len(data)
  350. data = cleaner.filter_invalid_types(data)
  351. if len(data) < non_filtered_len:
  352. data = cleaner.convert_types(data)
  353. data = cleaner.filter_invalid_null_values(data)
  354. data = cleaner.filter_invalid_patterns(data)
  355. data = cleaner.filter_notallowed_values(data)
  356. print("Done!")