ParseMapping.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Sep 20 15:33:17 2019
  5. @author: tanya
  6. """
  7. import os
  8. import sys
  9. import numpy as np
  10. sys.path.append(os.getcwd())
  11. class ParseMapping:
  12. '''
  13. '''
  14. def __init__(self, mapping_path: str, log_name: str = "ParseMapping",
  15. source_name_tag: str = "original_name", target_name_tag: str = "mongo_name",
  16. target_collection_tag: str = "mongo_collection",
  17. target_collection_name: str = None):
  18. '''
  19. '''
  20. import json
  21. from cdplib.log import Log
  22. self._log = Log('Parse Mapping')
  23. if not os.path.isfile(mapping_path):
  24. self._log.log_and_raise_error("Mapping not found")
  25. try:
  26. with open(mapping_path, "r") as f:
  27. self._mapping = json.load(f)
  28. except Exception as e:
  29. self._log.log_and_raise_error("Could not load mapping. "
  30. "Exit with error {}".format(e))
  31. self._mapping_path = mapping_path
  32. self._source_name_tag = source_name_tag
  33. self._target_name_tag = target_name_tag
  34. self._target_collection_tag = target_collection_tag
  35. self._target_collection_name = target_collection_name
  36. self._restrict_mapping_to_collection()
  37. def _restrict_mapping_to_collection(self):
  38. '''
  39. '''
  40. if self._target_collection_name is not None:
  41. result = []
  42. for d in self._mapping:
  43. if self._target_collection_tag not in d:
  44. continue
  45. for key in [self._target_name_tag, self._target_collection_tag]:
  46. if not isinstance(d[key], list):
  47. d[key] = [d[key]]
  48. if (len(d[self._target_collection_tag]) > 1) and (len(d[self._target_name_tag]) == 1):
  49. d[self._target_name_tag] = d[self._target_name_tag]*len(d[self._target_collection_tag])
  50. if len(d[self._target_collection_tag]) != len(d[self._target_name_tag]):
  51. self._log.log_and_raise_error(("In the migration mapping '{0}' {1} "
  52. "{2} has an unclear collection")
  53. .format(self._mapping_path,
  54. self._target_name_tag,
  55. d[self._target_name_tag]))
  56. if self._target_collection_name in d[self._target_collection_tag]:
  57. d[self._target_name_tag] = d[self._target_name_tag][d[self._target_collection_tag].index(self._target_collection_name)]
  58. d[self._target_collection_tag] = self._target_collection_name
  59. result.append(d)
  60. self._mapping = result
  61. def get_field_mapping(self) -> dict:
  62. '''
  63. '''
  64. assert(all([set([self._source_name_tag, self._target_name_tag]) <= set(d)
  65. for d in self._mapping]))
  66. return {d[self._source_name_tag]: d[self._target_name_tag] for d in self._mapping}
  67. def _get_fields_satistisfying_condition(self, key: str, value) -> list:
  68. '''
  69. '''
  70. assert(all([self._source_name_tag in d for d in self._mapping])),\
  71. "Invalid from field"
  72. return [d[self._source_name_tag] for d in self._mapping
  73. if (key in d) and (d[key] == value)]
  74. def get_required_fields(self) -> list:
  75. '''
  76. '''
  77. return self._get_fields_satistisfying_condition(key="required",
  78. value=1)
  79. def get_date_fields(self) -> list:
  80. '''
  81. '''
  82. return self._get_fields_satistisfying_condition(key="type",
  83. value="Date")
  84. """
  85. def get_fields_restricted_to_collecton(self, collection_name: str) -> list:
  86. '''
  87. '''
  88. target_collection_tag_mapping = {d[self._source_name_tag]: d[self._target_collection_tag]
  89. for d in self._mapping
  90. if (self._target_collection_tag in d)}
  91. target_collection_tag_mapping = {k: v if isinstance(v, list) else [v]
  92. for k, v in target_collection_tag_mapping.items()}
  93. return [k for k,v in target_collection_tag_mapping.items()
  94. if collection_name in v]
  95. """
  96. def _get_info(self, key: str, value=None) -> dict:
  97. '''
  98. '''
  99. assert(all([self._source_name_tag in d for d in self._mapping])),\
  100. "Invalid from field"
  101. return {d[self._source_name_tag]: d[key] for d in self._mapping
  102. if (key in d) and ((value is not None)
  103. and (d[key] == value)) or (key in d)}
  104. def get_default_values(self) -> dict:
  105. '''
  106. '''
  107. return self._get_info(key="default_values")
  108. def get_date_formats(self) -> dict:
  109. '''
  110. '''
  111. return self._get_info(key="date_format")
  112. def get_types(self) -> dict:
  113. '''
  114. '''
  115. return self._get_info(key="type")
  116. def get_python_types(self) -> dict:
  117. '''
  118. '''
  119. sql_to_python_dtypes = {
  120. "Text": str,
  121. "Date": np.dtype('<M8[ns]'),
  122. "Double": float,
  123. "Integer": int
  124. }
  125. sql_types = self.get_types()
  126. return {k: sql_to_python_dtypes[v] for k, v in sql_types.items()}
  127. def get_value_mappings(self) -> dict:
  128. '''
  129. '''
  130. return self._get_info(key="value_mapping")
  131. def get_column_numbers(self) -> list:
  132. '''
  133. '''
  134. if all(["column_number" in d for d in self._mapping]):
  135. column_numbers = [d["column_number"] for d in self._mapping]
  136. elif all(["column_number" not in d for d in self._mapping]):
  137. column_numbers = list(range(len(self._mapping)))
  138. else:
  139. err = ("Incorrectly filled mapping. Column numbers should ",
  140. "either in all or in neither of the fields")
  141. self._log.err(err)
  142. raise Exception(err)
  143. return column_numbers
  144. if __name__ == "__main__":
  145. mapping_path = os.path.join(".", "migration_mappings", "rs0_mapping.json")
  146. if os.path.isfile(mapping_path):
  147. print("found mapping path")
  148. parser = ParseMapping(mapping_path, source_name_tag="internal_name",
  149. target_name_tag="mongo_name")
  150. internal_to_mongo_mapping = parser.get_field_mapping()
  151. original_to_internal_mapping = parser.get_field_mapping()
  152. default_values = parser.get_default_values()
  153. types = parser.get_types()
  154. column_numbers = parser.get_column_numbers()
  155. print("Done testing!")