#1 Restriction to mongo collection

開啟中
tanja 請求將 4 次代碼提交從 tanja/restriction_to_collection 合併至 tanja/master
共有 3 個文件被更改,包括 118 次插入53 次删除
  1. 36 15
      cdplib/db_migration/MigrationCleaning.py
  2. 12 13
      cdplib/db_migration/ParseJsonSchema.py
  3. 70 25
      cdplib/db_migration/ParseMapping.py

+ 36 - 15
cdplib/db_migration/MigrationCleaning.py

@@ -5,7 +5,6 @@ Created on Wed Sep 25 08:09:52 2019
 
 @author: tanya
 """
-
 import os
 import sys
 import pandas as pd
@@ -30,8 +29,9 @@ class MigrationCleaning:
                  schema_paths: (str, list),
                  inconsist_report_table: str = None,
                  filter_index_columns: (str, list) = None,
-                 mapping_source: str = "internal_name",
-                 mapping_target: str = "mongo_name",
+                 mapping_source_name_tag: str = "internal_name",
+                 mapping_target_name_tag: str = "mongo_name",
+                 target_collection_name: str = None,
                  mapping_parser: type = ParseMapping,
                  schema_parser: type = ParseJsonSchema):
         '''
@@ -39,21 +39,41 @@ class MigrationCleaning:
         self.log = Log('Migration Cleaning')
         self._exception_handler = ExceptionsHandler()
 
-        assert isinstance(inconsist_report_table, str),\
-            "Inconsistent report table should be a tablename string"
+        if inconsist_report_table is not None:
+            assert isinstance(inconsist_report_table, str),\
+                "Inconsistent report table should be a tablename string"
 
         self._inconsist_report_table = inconsist_report_table
 
-        assert isinstance(filter_index_columns, (str, list)),\
-            "Filter index columns must be a str or a list"
+        if filter_index_columns is not None:
+            assert isinstance(filter_index_columns, (str, list)),\
+                "Filter index columns must be a str or a list"
+
+            self._filter_index_columns = list(filter_index_columns)
 
-        self._filter_index_columns = list(filter_index_columns)
+        else:
 
-        self._schema_parser = schema_parser(schema_paths)
+            self._filter_index_columns = None
 
         self._mapping_parser = mapping_parser(mapping_path,
-                                              source=mapping_source,
-                                              target=mapping_target)
+                                              source_name_tag=mapping_source_name_tag,
+                                              target_name_tag=mapping_target_name_tag,
+                                              target_collection_name=target_collection_name)
+
+        if target_collection_name is not None:
+
+            schema_names = [os.path.basename(schema_path) for schema_path in schema_paths]
+
+            convention_schema_name = "schema_" + target_collection_name + ".json"
+
+            if convention_schema_name not in schema_names:
+                self._log.log_and_raise_warning("Found no matching of the collection name {0} in schema paths {1}"
+                                                .format(target_collection_name, schema_paths))
+            else:
+                self._schema_parser = schema_parser(schema_paths[schema_names.index(convention_schema_name)])
+
+        else:
+            self._schema_parser = schema_parser(schema_paths)
 
         self._mapping_path = mapping_path
         self._schema_paths = schema_paths
@@ -68,7 +88,7 @@ class MigrationCleaning:
             "Parameter 'data' must be a pandas dataframe"
 
     @property
-    def _field_mapping(self):
+    def _field_mapping(self, collection_name: str = None):
         '''
         '''
         return self._mapping_parser.get_field_mapping()
@@ -503,6 +523,7 @@ class MigrationCleaning:
 
         return data
 
+    """
     def restrict_to_collection(self, data: pd.DataFrame, collection_name: str) -> pd.DataFrame:
         '''
         '''
@@ -511,6 +532,7 @@ class MigrationCleaning:
         fields = self._mapping_parser.get_fields_restricted_to_collecton(collection_name=collection_name)
 
         return data[[c for c in data.columns if (c in fields) or (c in mongo_fields)]]
+    """
 
 
 if __name__ == "__main__":
@@ -532,8 +554,8 @@ if __name__ == "__main__":
         cleaner = MigrationCleaning(
                 mapping_path=mapping_path,
                 schema_paths=schema_paths,
-                mapping_source="internal_name",
-                mapping_target="mongo_name",
+                mapping_source_name_tag="internal_name",
+                mapping_target_name_tag="mongo_name",
                 filter_index_columns=["radsatznummer"],
                 inconsist_report_table=inconsist_report_table)
 
@@ -562,4 +584,3 @@ if __name__ == "__main__":
         data = cleaner.filter_notallowed_values(data)
 
     print("Done!")
-    

+ 12 - 13
cdplib/db_migration/ParseJsonSchema.py

@@ -49,13 +49,13 @@ class ParseJsonSchema(ParseDbSchema):
         for schema_path in schema_paths:
             try:
                 with open(schema_path, "r") as f:
-                   schema = json.load(f) 
+                   schema = json.load(f)
                 # Load schmea dereferenced and cleaned by default values
                 self.schemas.append(self._dereference_schema(schema))
 
             except Exception as e:
-                err = ("Could not load json schema, "
-                       "Obtained error {}".format(e))
+                err = ("Could not load json schema {0}, "
+                       "Obtained error {1}".format(schema_path, e))
 
                 self._log.error(err)
                 raise Exception(err)
@@ -66,7 +66,7 @@ class ParseJsonSchema(ParseDbSchema):
         '''
         # Don't use strip() instaed of replace since schema_c.strip(schema_)
         # will discard the c as well which is not a appropriate output
-        return [os.path.basename(p).replace("schema_","").split(".")[0] for p in self._schema_paths]
+        return [os.path.basename(p).replace("schema_", "").split(".")[0] for p in self._schema_paths]
 
     def get_fields(self) -> list:
         '''
@@ -338,7 +338,7 @@ class ParseJsonSchema(ParseDbSchema):
 
         assert(isinstance(schema, dict)),\
             "Parameter 'schema' must be a dictionary type"
-            
+
         base_dir_url = Path(os.path.join(os.getcwd(), "mongo_schema")).as_uri() + '/'
         schema = jsonref.loads(str(schema).replace("'", "\""), base_uri=base_dir_url)
         schema = deepcopy(schema)
@@ -354,11 +354,11 @@ class ParseJsonSchema(ParseDbSchema):
         if 'default_values' in schema:
             del schema['default_values']
         return schema
-    
+
         assert(isinstance(schema, dict)),\
         "Parameter 'schema' must be a dictionary type"
-    
-    # Need to parse schmema for importing to mongo db 
+
+    # Need to parse schmema for importing to mongo db
     # Reason:
     # We need to drop default values since MongoDB can't handle them
     # We need to deference json before import to Mongo DB pymongo can't deal with references
@@ -379,8 +379,8 @@ class ParseJsonSchema(ParseDbSchema):
             schema = self._dereference_schema(schema)
 
         return schema
-           
-   
+
+
     def _analyze_schema(self, schema: dict, definitions_flag: bool = False) -> dict:
 
 
@@ -401,10 +401,10 @@ class ParseJsonSchema(ParseDbSchema):
 
 if __name__ == "__main__":
 
-#     Only for testing
+    # Only for testing
 
     schema_path = os.path.join(".", "mongo_schema", "schema_components.json")
-    
+
     if os.path.isfile(schema_path):
 
         parse_obj = ParseJsonSchema(schema_paths=schema_path)
@@ -424,4 +424,3 @@ if __name__ == "__main__":
         allowed_values = parse_obj.get_allowed_values()
 
         descriptions = parse_obj.get_field_descriptions()
-    

+ 70 - 25
cdplib/db_migration/ParseMapping.py

@@ -15,49 +15,85 @@ class ParseMapping:
     '''
     '''
     def __init__(self, mapping_path: str, log_name: str = "ParseMapping",
-                 source: str = "original_name", target: str = "mongo_name",
-                 target_collection: str = "mongo_collection"):
+                 source_name_tag: str = "original_name", target_name_tag: str = "mongo_name",
+                 target_collection_tag: str = "mongo_collection",
+                 target_collection_name: str = None):
         '''
         '''
         import json
         from cdplib.log import Log
 
-        self.log = Log('Parse Mapping')
+        self._log = Log('Parse Mapping')
 
         if not os.path.isfile(mapping_path):
-            err = "Mapping not found"
-            self._log.error(err)
-            raise FileNotFoundError(err)
-
+            self._log.log_and_raise_error("Mapping not found")
         try:
             with open(mapping_path, "r") as f:
                 self._mapping = json.load(f)
 
         except Exception as e:
-            err = ("Could not load mapping. "
-                   "Exit with error {}".format(e))
-            self._log.error(err)
-            raise Exception(err)
+            self._log.log_and_raise_error("Could not load mapping. "
+                                          "Exit with error {}".format(e))
+
+        self._mapping_path = mapping_path
+        self._source_name_tag = source_name_tag
+        self._target_name_tag = target_name_tag
+        self._target_collection_tag = target_collection_tag
+        self._target_collection_name = target_collection_name
+
+        self._restrict_mapping_to_collection()
+
+    def _restrict_mapping_to_collection(self):
+        '''
+        '''
+        if self._target_collection_name is not None:
+
+            result = []
+
+            for d in self._mapping:
+
+                if self._target_collection_tag not in d:
+                    continue
 
-        self._source = source
-        self._target = target
-        self._target_collection = target_collection
+                for key in [self._target_name_tag, self._target_collection_tag]:
+                    if not isinstance(d[key], list):
+                        d[key] = [d[key]]
+
+                if (len(d[self._target_collection_tag]) > 1) and (len(d[self._target_name_tag]) == 1):
+                    d[self._target_name_tag] = d[self._target_name_tag]*len(d[self._target_collection_tag])
+
+                if len(d[self._target_collection_tag]) != len(d[self._target_name_tag]):
+                    self._log.log_and_raise_error(("In the migration mapping '{0}' {1} "
+                                                   "{2} has an unclear collection")
+                                                   .format(self._mapping_path,
+                                                           self._target_name_tag,
+                                                           d[self._target_name_tag]))
+
+                if self._target_collection_name in d[self._target_collection_tag]:
+
+                    d[self._target_name_tag] = d[self._target_name_tag][d[self._target_collection_tag].index(self._target_collection_name)]
+
+                    d[self._target_collection_tag] = self._target_collection_name
+
+                    result.append(d)
+
+            self._mapping = result
 
     def get_field_mapping(self) -> dict:
         '''
         '''
-        assert(all([set([self._source, self._target]) <= set(d)
+        assert(all([set([self._source_name_tag, self._target_name_tag]) <= set(d)
                     for d in self._mapping]))
 
-        return {d[self._source]: d[self._target] for d in self._mapping}
+        return {d[self._source_name_tag]: d[self._target_name_tag] for d in self._mapping}
 
     def _get_fields_satistisfying_condition(self, key: str, value) -> list:
         '''
         '''
-        assert(all([self._source in d for d in self._mapping])),\
+        assert(all([self._source_name_tag in d for d in self._mapping])),\
             "Invalid from field"
 
-        return [d[self._source] for d in self._mapping
+        return [d[self._source_name_tag] for d in self._mapping
                 if (key in d) and (d[key] == value)]
 
     def get_required_fields(self) -> list:
@@ -72,19 +108,28 @@ class ParseMapping:
         return self._get_fields_satistisfying_condition(key="type",
                                                         value="Date")
 
+    """
     def get_fields_restricted_to_collecton(self, collection_name: str) -> list:
         '''
         '''
-        return self._get_fields_satistisfying_condition(key=self._target_collection,
-                                                        value=collection_name)
+        target_collection_tag_mapping = {d[self._source_name_tag]: d[self._target_collection_tag]
+                                     for d in self._mapping
+                                     if (self._target_collection_tag in d)}
+
+        target_collection_tag_mapping = {k: v if isinstance(v, list) else [v]
+                                     for k, v in target_collection_tag_mapping.items()}
+
+        return [k for k,v in target_collection_tag_mapping.items()
+                if collection_name in v]
+    """
 
     def _get_info(self, key: str, value=None) -> dict:
         '''
         '''
-        assert(all([self._source in d for d in self._mapping])),\
+        assert(all([self._source_name_tag in d for d in self._mapping])),\
             "Invalid from field"
 
-        return {d[self._source]: d[key] for d in self._mapping
+        return {d[self._source_name_tag]: d[key] for d in self._mapping
                 if (key in d) and ((value is not None)
                 and (d[key] == value)) or (key in d)}
 
@@ -134,7 +179,7 @@ class ParseMapping:
         else:
             err = ("Incorrectly filled mapping. Column numbers should ",
                    "either in all or in neither of the fields")
-            self.log.err(err)
+            self._log.err(err)
             raise Exception(err)
 
         return column_numbers
@@ -148,8 +193,8 @@ if __name__ == "__main__":
 
         print("found mapping path")
 
-        parser = ParseMapping(mapping_path, source="internal_name",
-                              target="mongo_name")
+        parser = ParseMapping(mapping_path, source_name_tag="internal_name",
+                              target_name_tag="mongo_name")
 
         internal_to_mongo_mapping = parser.get_field_mapping()