#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 20 15:33:17 2019

@author: tanya

Parse a JSON field mapping (a list of per-field dictionaries) and expose
views of it: source-to-target names, required/date fields, default values,
date formats, types, value mappings, internal/mongo names and column numbers.
"""

import json
import os
import sys

import numpy as np

sys.path.append(os.getcwd())


class ParseMapping:
    '''
    Load a JSON mapping file and answer queries about the fields it describes.

    The loaded mapping is iterated as a sequence of dictionaries, one per
    field. Each query reads particular keys of those dictionaries and, where
    a per-field result is returned as a dict, keys it by the entry's
    ``source`` value.
    '''
    def __init__(self, mapping_path: str, log_name: str = "ParseMapping",
                 source: str = "original_name", target: str = "mongo_name",
                 target_collection: str = "mongo_collection"):
        '''
        :param mapping_path: path to the JSON mapping file.
        :param log_name: accepted for compatibility; currently not used.
        :param source: entry key whose value identifies a field in results.
        :param target: entry key holding the target name
            (used by get_field_mapping).
        :param target_collection: entry key holding the collection name
            (used by get_fields_restricted_to_collecton).
        :raises FileNotFoundError: if mapping_path is not an existing file.
        :raises Exception: if the file cannot be read or parsed as JSON
            (the original error is chained as the cause).
        '''
        from cdplib.log import Log

        # NOTE(review): log_name is not used; the logger name stays fixed so
        # existing log filtering keeps matching.
        self._log = Log('Parse Mapping')

        if not os.path.isfile(mapping_path):
            err = "Mapping not found "+mapping_path
            self._log.error(err)
            raise FileNotFoundError(err)

        try:
            # JSON text is UTF-8 by specification; don't rely on the
            # platform's default encoding (matters for non-ASCII field names).
            with open(mapping_path, "r", encoding="utf-8") as f:
                self._mapping = json.load(f)

        except Exception as e:
            err = ("Could not load mapping. " + mapping_path +
                   ". Exit with error {}".format(e))
            self._log.error(err)
            raise Exception(err) from e

        self._source = source
        self._target = target
        self._target_collection = target_collection

    def get_field_mapping(self) -> dict:
        '''
        Return {source value: target value} for every mapping entry.

        Every entry must contain both the source and the target key.
        '''
        assert all({self._source, self._target} <= set(d)
                   for d in self._mapping), \
            "Every mapping entry must contain the source and target keys"

        return {d[self._source]: d[self._target] for d in self._mapping}

    def _get_fields_satistisfying_condition(self, key: str, value) -> list:
        '''
        Return the source values of entries whose ``key`` equals ``value``.

        Entries that lack ``key`` are skipped. Order follows the mapping.
        '''
        assert(all([self._source in d for d in self._mapping])),\
            "Invalid from field"

        return [d[self._source] for d in self._mapping
                if (key in d) and (d[key] == value)]

    def get_required_fields(self) -> list:
        '''
        Return the source values of entries with "required" equal to 1.
        '''
        return self._get_fields_satistisfying_condition(key="required",
                                                        value=1)

    def get_date_fields(self) -> list:
        '''
        Return the source values of entries with "type" equal to "Date".
        '''
        return self._get_fields_satistisfying_condition(key="type",
                                                        value="Date")

    def get_fields_restricted_to_collecton(self, collection_name: str) -> list:
        '''
        Return the source values of entries whose target-collection key
        equals ``collection_name``.
        '''
        return self._get_fields_satistisfying_condition(
                key=self._target_collection, value=collection_name)

    def _get_info(self, key: str, value=None) -> dict:
        '''
        Return {source value: entry[key]} for entries that contain ``key``.

        If ``value`` is given, only entries whose ``key`` equals ``value``
        are kept. Previously an operator-precedence slip made this filter
        a no-op.
        '''
        assert(all([self._source in d for d in self._mapping])),\
            "Invalid from field"

        return {d[self._source]: d[key] for d in self._mapping
                if key in d and (value is None or d[key] == value)}

    def get_default_values(self) -> dict:
        '''
        Return {source value: "default_values" entry} where present.
        '''
        return self._get_info(key="default_values")

    def get_date_formats(self) -> dict:
        '''
        Return {source value: "date_format" entry} where present.
        '''
        return self._get_info(key="date_format")

    def _get_values_in_all_or_none(self, key: str, what: str) -> list:
        '''
        Return ``[entry[key] ...]`` when every entry has ``key``, or
        ``[0, 1, ..., n-1]`` when no entry has it.

        :param what: human-readable name of the values, used in the error.
        :raises Exception: if only some entries contain ``key`` (also logged).
        '''
        present = [key in d for d in self._mapping]

        if all(present):
            return [d[key] for d in self._mapping]

        if not any(present):
            # Fall back to positional indices.
            return list(range(len(self._mapping)))

        err = ("Incorrectly filled mapping. {} should "
               "either be in all or in neither of the fields".format(what))
        self._log.error(err)
        raise Exception(err)

    def get_internal_names(self) -> list:
        '''
        Return the "internal_name" of every entry, or positional indices
        when no entry has one.

        :raises Exception: if only some entries have "internal_name".
        '''
        return self._get_values_in_all_or_none("internal_name",
                                               "Internal names")

    def get_mongo_names(self) -> list:
        '''
        Return the "mongo_name" of every entry, or positional indices
        when no entry has one.

        :raises Exception: if only some entries have "mongo_name".
        '''
        return self._get_values_in_all_or_none("mongo_name", "Mongo names")

    def get_types(self) -> dict:
        '''
        Return {source value: "type" entry} where present.
        '''
        return self._get_info(key="type")

    def get_python_types(self) -> dict:
        '''
        Return {source value: python type} by translating each field's
        "type" through a fixed table.

        A "type" value missing from the table raises KeyError.
        '''
        # NOTE(review): part of this method was lost in the reviewed copy
        # (everything after "Date": np.dtype(' up to the next "->"). The
        # remaining table entries and the return are a reconstruction;
        # confirm against version control.
        sql_to_python_dtypes = {
            "Text": str,
            "Date": np.dtype('<M8[ns]'),
            "Double": float,
            "Integer": int,
        }

        sql_types = self.get_types()

        return {k: sql_to_python_dtypes[v] for k, v in sql_types.items()}

    # NOTE(review): this method's name was lost in the reviewed copy and is
    # reconstructed as get_value_mappings; confirm against callers.
    def get_value_mappings(self) -> dict:
        '''
        Return {source value: "value_mapping" entry} where present.
        '''
        return self._get_info(key="value_mapping")

    def get_column_numbers(self) -> list:
        '''
        Return the "column_number" of every entry, or positional indices
        when no entry has one.

        :raises Exception: if only some entries have "column_number"
            (previously masked by a tuple message and a call to a
            non-existent ``self._log.err``).
        '''
        return self._get_values_in_all_or_none("column_number",
                                               "Column numbers")


if __name__ == "__main__":

    # Manual smoke test against a mapping file relative to the working dir.
    mapping_path = os.path.join(".", "migration_mappings", "rs0_mapping.json")

    if os.path.isfile(mapping_path):

        print("found mapping path")

    parser = ParseMapping(mapping_path, source="internal_name",
                          target="mongo_name")

    internal_to_mongo_mapping = parser.get_field_mapping()

    # NOTE(review): identical to the call above; presumably intended to use
    # a parser with different source/target keys -- confirm.
    original_to_internal_mapping = parser.get_field_mapping()

    default_values = parser.get_default_values()

    types = parser.get_types()

    column_numbers = parser.get_column_numbers()

    print("Done testing!")