#2 Added a new class with useful sql queries: merge after restriction_to_collection

開啟中
tanja 請求將 8 次代碼提交從 tanja/sql_queries_class 合併至 tanja/master

+ 108 - 0
cdplib/db_handlers/SQLQueries.py

@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Dec  5 10:26:27 2019
+
+@author: tanya
+"""
+
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+
+class SQLQueries:
+    '''
+    Summary of useful complicated sql queries
+    '''
+    def __init__(self):
+        '''
+        '''
+        from cdplib.log import Log
+
+        self._log = Log("SQLQueries")
+
+    def get_combinations(self, instances: dict, table: str, columns: list = None) -> str:
+        '''
+        Returns a query to read data corresponding to combinations of given values for given columns
+         specified in the argument instances.
+
+        :param instances: a dictionary of form {"column_name_1": [val_1, val2, val3], "column_name_2": [val]}
+        :param table: table name or subquery
+        :param columns: list of column names to read from the table
+        '''
+        from copy import deepcopy
+
+        instances = deepcopy(instances)
+
+        if len(instances) == 0:
+            return table[1:-1] # removed external paranthesis
+
+        else:
+
+            column, values = list(instances.items())[0]
+
+            instances.pop(column)
+
+            if table.startswith("(SELECT"):
+                prefix = column + "_subquery."
+                sub_query_name = " " + prefix[:-1]  # removed the dot in the end
+            else:
+                prefix = ""
+                sub_query_name = prefix
+
+            if columns is None:
+                columns_subquery = "*"
+            else:
+                columns_with_prefix = [prefix + c for c in columns]
+
+                columns_subquery = ", ".join(columns_with_prefix)
+
+            if len(values) > 1:
+                table = "(SELECT DISTINCT {0} FROM {1}{2} WHERE {3}{4} IN {5})"\
+                        .format(columns_subquery, table, sub_query_name, prefix, column, tuple(values))
+
+            else:
+                table = "(SELECT DISTINCT {0} FROM {1}{2} WHERE {3}{4} = {5})"\
+                        .format(columns_subquery, table, sub_query_name, prefix, column, values[0])
+
+            return self.get_combinations(instances, table, columns)
+
+    def get_cooccurances(self, instances: dict, table: str, columns: list = None) -> str:
+        '''
+        Returns a query to read data corresponding to cooccurances of given values for given columns
+         specified in the argument instances.
+
+        :param instances: a dictionary of form {"column_name_1": [val_11, val12, val13],
+                                                "column_name_2": [val_21, val22, val23]}
+         (all the lists in the instances need to be of the same length)
+        :param table: table name or subquery
+        :param columns: list of column names to read from the table
+        '''
+        value_length = len(list(instances.values())[0])
+
+        if not all([len(vals) == value_length for vals in instances.values()]):
+            self._log.log_and_raise_error("All the values in instances should be of the same length")
+
+        if columns is None:
+            columns_subquery = "*"
+        else:
+            columns_subquery = ", ".join(columns)
+
+        if len(instances) == 1:
+            column = list(instances.keys())[0]
+
+            values = instances[column]
+
+            where_clause = "{0} IN {1}".format(column, values)
+
+        else:
+            and_statements = [" AND ".join(["{0} = {1}".format(column, instances[column][i])
+                              for column in instances.keys()]) for i in range(value_length)]
+
+            and_statements = ["({})".format(s) for s in and_statements]
+
+            where_clause = " OR ".join(and_statements)
+
+        return "SELECT DISTINCT {0} FROM {1} WHERE {2}".format(columns_subquery, table, where_clause)

+ 2 - 1
cdplib/db_handlers/__init__.py

@@ -1,2 +1,3 @@
 from .MongodbHandler import *
-from .SQLHandler import *
+from .SQLHandler import *
+from .SQLQueries import *

+ 36 - 15
cdplib/db_migration/MigrationCleaning.py

@@ -5,7 +5,6 @@ Created on Wed Sep 25 08:09:52 2019
 
 @author: tanya
 """
-
 import os
 import sys
 import pandas as pd
@@ -30,8 +29,9 @@ class MigrationCleaning:
                  schema_paths: (str, list),
                  inconsist_report_table: str = None,
                  filter_index_columns: (str, list) = None,
-                 mapping_source: str = "internal_name",
-                 mapping_target: str = "mongo_name",
+                 mapping_source_name_tag: str = "internal_name",
+                 mapping_target_name_tag: str = "mongo_name",
+                 target_collection_name: str = None,
                  mapping_parser: type = ParseMapping,
                  schema_parser: type = ParseJsonSchema):
         '''
@@ -39,21 +39,41 @@ class MigrationCleaning:
         self.log = Log('Migration Cleaning')
         self._exception_handler = ExceptionsHandler()
 
-        assert isinstance(inconsist_report_table, str),\
-            "Inconsistent report table should be a tablename string"
+        if inconsist_report_table is not None:
+            assert isinstance(inconsist_report_table, str),\
+                "Inconsistent report table should be a tablename string"
 
         self._inconsist_report_table = inconsist_report_table
 
-        assert isinstance(filter_index_columns, (str, list)),\
-            "Filter index columns must be a str or a list"
+        if filter_index_columns is not None:
+            assert isinstance(filter_index_columns, (str, list)),\
+                "Filter index columns must be a str or a list"
+
+            self._filter_index_columns = list(filter_index_columns)
 
-        self._filter_index_columns = list(filter_index_columns)
+        else:
 
-        self._schema_parser = schema_parser(schema_paths)
+            self._filter_index_columns = None
 
         self._mapping_parser = mapping_parser(mapping_path,
-                                              source=mapping_source,
-                                              target=mapping_target)
+                                              source_name_tag=mapping_source_name_tag,
+                                              target_name_tag=mapping_target_name_tag,
+                                              target_collection_name=target_collection_name)
+
+        if target_collection_name is not None:
+
+            schema_names = [os.path.basename(schema_path) for schema_path in schema_paths]
+
+            convention_schema_name = "schema_" + target_collection_name + ".json"
+
+            if convention_schema_name not in schema_names:
+                self._log.log_and_raise_warning("Found no matching of the collection name {0} in schema paths {1}"
+                                                .format(target_collection_name, schema_paths))
+            else:
+                self._schema_parser = schema_parser(schema_paths[schema_names.index(convention_schema_name)])
+
+        else:
+            self._schema_parser = schema_parser(schema_paths)
 
         self._mapping_path = mapping_path
         self._schema_paths = schema_paths
@@ -68,7 +88,7 @@ class MigrationCleaning:
             "Parameter 'data' must be a pandas dataframe"
 
     @property
-    def _field_mapping(self):
+    def _field_mapping(self, collection_name: str = None):
         '''
         '''
         return self._mapping_parser.get_field_mapping()
@@ -503,6 +523,7 @@ class MigrationCleaning:
 
         return data
 
+    """
     def restrict_to_collection(self, data: pd.DataFrame, collection_name: str) -> pd.DataFrame:
         '''
         '''
@@ -511,6 +532,7 @@ class MigrationCleaning:
         fields = self._mapping_parser.get_fields_restricted_to_collecton(collection_name=collection_name)
 
         return data[[c for c in data.columns if (c in fields) or (c in mongo_fields)]]
+    """
 
 
 if __name__ == "__main__":
@@ -532,8 +554,8 @@ if __name__ == "__main__":
         cleaner = MigrationCleaning(
                 mapping_path=mapping_path,
                 schema_paths=schema_paths,
-                mapping_source="internal_name",
-                mapping_target="mongo_name",
+                mapping_source_name_tag="internal_name",
+                mapping_target_name_tag="mongo_name",
                 filter_index_columns=["radsatznummer"],
                 inconsist_report_table=inconsist_report_table)
 
@@ -562,4 +584,3 @@ if __name__ == "__main__":
         data = cleaner.filter_notallowed_values(data)
 
     print("Done!")
-    

+ 16 - 18
cdplib/db_migration/ParseJsonSchema.py

@@ -49,13 +49,12 @@ class ParseJsonSchema(ParseDbSchema):
         for schema_path in schema_paths:
             try:
                 with open(schema_path, "r") as f:
-                   schema = json.load(f) 
                 # Load schmea dereferenced and cleaned by default values
-                self.schemas.append(self._dereference_schema(schema))
+                self.schemas.append(self._dereference_schema(schema_path))
 
             except Exception as e:
-                err = ("Could not load json schema, "
-                       "Obtained error {}".format(e))
+                err = ("Could not load json schema {0}, "
+                       "Obtained error {1}".format(schema_path, e))
 
                 self._log.error(err)
                 raise Exception(err)
@@ -66,7 +65,7 @@ class ParseJsonSchema(ParseDbSchema):
         '''
         # Don't use strip() instaed of replace since schema_c.strip(schema_)
         # will discard the c as well which is not a appropriate output
-        return [os.path.basename(p).replace("schema_","").split(".")[0] for p in self._schema_paths]
+        return [os.path.basename(p).replace("schema_", "").split(".")[0] for p in self._schema_paths]
 
     def get_fields(self) -> list:
         '''
@@ -331,16 +330,16 @@ class ParseJsonSchema(ParseDbSchema):
 
         return already_parsed
 
-    def _dereference_schema(self, schema: dict) -> dict:
+    def _dereference_schema(self, schema_path: str) -> dict:
         '''
         :param dict schema: dictionary containing a schema which uses references.
         '''
 
-        assert(isinstance(schema, dict)),\
-            "Parameter 'schema' must be a dictionary type"
-            
+        assert(isinstance(schema_path, str)),\
+            "Parameter 'schema_path' must be a string type"
+
         base_dir_url = Path(os.path.join(os.getcwd(), "mongo_schema")).as_uri() + '/'
-        schema = jsonref.loads(str(schema).replace("'", "\""), base_uri=base_dir_url)
+        schema = jsonref.loads(open(schema_path,"r").read(), base_uri=base_dir_url)
         schema = deepcopy(schema)
         schema.pop('definitions', None)
         return schema
@@ -354,11 +353,11 @@ class ParseJsonSchema(ParseDbSchema):
         if 'default_values' in schema:
             del schema['default_values']
         return schema
-    
+
         assert(isinstance(schema, dict)),\
         "Parameter 'schema' must be a dictionary type"
-    
-    # Need to parse schmema for importing to mongo db 
+
+    # Need to parse schmema for importing to mongo db
     # Reason:
     # We need to drop default values since MongoDB can't handle them
     # We need to deference json before import to Mongo DB pymongo can't deal with references
@@ -379,8 +378,8 @@ class ParseJsonSchema(ParseDbSchema):
             schema = self._dereference_schema(schema)
 
         return schema
-           
-   
+
+
     def _analyze_schema(self, schema: dict, definitions_flag: bool = False) -> dict:
 
 
@@ -401,10 +400,10 @@ class ParseJsonSchema(ParseDbSchema):
 
 if __name__ == "__main__":
 
-#     Only for testing
+    # Only for testing
 
     schema_path = os.path.join(".", "mongo_schema", "schema_components.json")
-    
+
     if os.path.isfile(schema_path):
 
         parse_obj = ParseJsonSchema(schema_paths=schema_path)
@@ -424,4 +423,3 @@ if __name__ == "__main__":
         allowed_values = parse_obj.get_allowed_values()
 
         descriptions = parse_obj.get_field_descriptions()
-    

+ 77 - 26
cdplib/db_migration/ParseMapping.py

@@ -15,49 +15,91 @@ class ParseMapping:
     '''
     '''
     def __init__(self, mapping_path: str, log_name: str = "ParseMapping",
-                 source: str = "original_name", target: str = "mongo_name",
-                 target_collection: str = "mongo_collection"):
+                 source_name_tag: str = "original_name", target_name_tag: str = "mongo_name",
+                 target_collection_tag: str = "mongo_collection",
+                 target_collection_name: str = None):
         '''
         '''
         import json
         from cdplib.log import Log
 
-        self.log = Log('Parse Mapping')
+        self._log = Log('Parse Mapping')
 
         if not os.path.isfile(mapping_path):
-            err = "Mapping not found"
-            self._log.error(err)
-            raise FileNotFoundError(err)
-
+            self._log.log_and_raise_error("Mapping not found")
         try:
             with open(mapping_path, "r") as f:
                 self._mapping = json.load(f)
 
         except Exception as e:
-            err = ("Could not load mapping. "
-                   "Exit with error {}".format(e))
-            self._log.error(err)
-            raise Exception(err)
+            self._log.log_and_raise_error("Could not load mapping. "
+                                          "Exit with error {}".format(e))
+
+        self._mapping_path = mapping_path
+        self._source_name_tag = source_name_tag
+        self._target_name_tag = target_name_tag
+        self._target_collection_tag = target_collection_tag
+        self._target_collection_name = target_collection_name
 
-        self._source = source
-        self._target = target
-        self._target_collection = target_collection
+        self._restrict_mapping_to_collection()
 
-    def get_field_mapping(self) -> dict:
+    def _restrict_mapping_to_collection(self):
         '''
         '''
-        assert(all([set([self._source, self._target]) <= set(d)
+        if self._target_collection_name is not None:
+
+            result = []
+
+            for d in self._mapping:
+
+                if self._target_collection_tag not in d:
+                    continue
+
+                for key in [self._target_name_tag, self._target_collection_tag]:
+                    if not isinstance(d[key], list):
+                        d[key] = [d[key]]
+
+                if (len(d[self._target_collection_tag]) > 1) and (len(d[self._target_name_tag]) == 1):
+                    d[self._target_name_tag] = d[self._target_name_tag]*len(d[self._target_collection_tag])
+
+                if len(d[self._target_collection_tag]) != len(d[self._target_name_tag]):
+                    self._log.log_and_raise_error(("In the migration mapping '{0}' {1} "
+                                                   "{2} has an unclear collection")
+                                                   .format(self._mapping_path,
+                                                           self._target_name_tag,
+                                                           d[self._target_name_tag]))
+
+                if self._target_collection_name in d[self._target_collection_tag]:
+
+                    d[self._target_name_tag] = d[self._target_name_tag][d[self._target_collection_tag].index(self._target_collection_name)]
+
+                    d[self._target_collection_tag] = self._target_collection_name
+
+                    result.append(d)
+
+            self._mapping = result
+
+    def get_field_mapping(self, source_name_tag: str = None, target_name_tag: str = None) -> dict:
+        '''
+        '''
+        if source_name_tag is None:
+            source_name_tag = self._source_name_tag
+
+        if target_name_tag is None:
+            target_name_tag = self._target_name_tag
+
+        assert(all([set([source_name_tag, target_name_tag]) <= set(d)
                     for d in self._mapping]))
 
-        return {d[self._source]: d[self._target] for d in self._mapping}
+        return {d[source_name_tag]: d[self._target_name_tag] for d in self._mapping}
 
     def _get_fields_satistisfying_condition(self, key: str, value) -> list:
         '''
         '''
-        assert(all([self._source in d for d in self._mapping])),\
+        assert(all([self._source_name_tag in d for d in self._mapping])),\
             "Invalid from field"
 
-        return [d[self._source] for d in self._mapping
+        return [d[self._source_name_tag] for d in self._mapping
                 if (key in d) and (d[key] == value)]
 
     def get_required_fields(self) -> list:
@@ -72,19 +114,28 @@ class ParseMapping:
         return self._get_fields_satistisfying_condition(key="type",
                                                         value="Date")
 
+    """
     def get_fields_restricted_to_collecton(self, collection_name: str) -> list:
         '''
         '''
-        return self._get_fields_satistisfying_condition(key=self._target_collection,
-                                                        value=collection_name)
+        target_collection_tag_mapping = {d[self._source_name_tag]: d[self._target_collection_tag]
+                                     for d in self._mapping
+                                     if (self._target_collection_tag in d)}
+
+        target_collection_tag_mapping = {k: v if isinstance(v, list) else [v]
+                                     for k, v in target_collection_tag_mapping.items()}
+
+        return [k for k,v in target_collection_tag_mapping.items()
+                if collection_name in v]
+    """
 
     def _get_info(self, key: str, value=None) -> dict:
         '''
         '''
-        assert(all([self._source in d for d in self._mapping])),\
+        assert(all([self._source_name_tag in d for d in self._mapping])),\
             "Invalid from field"
 
-        return {d[self._source]: d[key] for d in self._mapping
+        return {d[self._source_name_tag]: d[key] for d in self._mapping
                 if (key in d) and ((value is not None)
                 and (d[key] == value)) or (key in d)}
 
@@ -134,7 +185,7 @@ class ParseMapping:
         else:
             err = ("Incorrectly filled mapping. Column numbers should ",
                    "either in all or in neither of the fields")
-            self.log.err(err)
+            self._log.err(err)
             raise Exception(err)
 
         return column_numbers
@@ -148,8 +199,8 @@ if __name__ == "__main__":
 
         print("found mapping path")
 
-        parser = ParseMapping(mapping_path, source="internal_name",
-                              target="mongo_name")
+        parser = ParseMapping(mapping_path, source_name_tag="internal_name",
+                              target_name_tag="mongo_name")
 
         internal_to_mongo_mapping = parser.get_field_mapping()