testStatisticalFeatures.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Oct 18 16:26:47 2018
  5. @author: tanya
  6. """
  7. import os
  8. import unittest
  9. import logging
  10. import pandas as pd
  11. import numpy as np
  12. from pandas.util.testing import assert_frame_equal
  13. from libraries.feature_engineering.in_memory_feature_engineering.StatisticalFeatures import StatisticalFeatures
  14. from libraries.logging.logging_utils import configure_logging
  15. class TestStatisticalFeatures(unittest.TestCase):
  16. '''
  17. '''
  18. def __init__(self, data = None, index_cols = None, path_to_log = None):
  19. '''
  20. '''
  21. if index_cols is None:
  22. self.index_cols = ['id1', 'id2']
  23. else:
  24. self.index_cols = index_cols
  25. if data is None:
  26. self.data = pd.DataFrame({'int' : [1,2,3,2,55,3,7],
  27. 'float' : [0.1, 7, 0.1, 99.9, 99.9, np.nan, 7],
  28. 'str' : ['a', np.nan, 'c', 'a', 'a', '', 'c'],
  29. 'datetime' : [pd.datetime(2017, 1, 2), np.nan, pd.datetime(2017, 5, 3), pd.datetime(2017, 1, 4),
  30. '2018-01-19', pd.datetime(2018, 1, 4), pd.datetime(2019, 3, 23)],
  31. 'nan' : [np.nan]*7,
  32. 'id1' : [1,1,3,3,3,1,1],
  33. 'id2' : ['a', 'a', 'b', 'b', 'a', 'a', np.nan]})\
  34. .sort_values(by = self.index_cols)
  35. else:
  36. self.data = data
  37. self.obj = StatisticalFeatures(data = self.data, index_cols = self.index_cols, path_to_log = path_to_log)
  38. class TestKpisByAggregation(TestStatisticalFeatures):
  39. '''
  40. '''
  41. def __init__(self, data = None, index_cols = None, path_to_log = None):
  42. '''
  43. '''
  44. super(TestKpisByAggregation, self).__init__(data = data, index_cols = index_cols, path_to_log = path_to_log)
  45. def test_builtin_aggfuncs_numeric_cols(self, answer = None, kpis = None):
  46. '''Tests the expected behaviour of pandas builtin aggregation function,
  47. in particular behaviour with missing values
  48. :param DataFrame data:
  49. :param list index_cols:
  50. :param DataFrame answer:
  51. :param list of tuples or dict kpis:
  52. '''
  53. kpis = kpis or [('int', ['min', 'std']),
  54. ('float', ['mean', np.sum]),
  55. ('float', 'sum'),
  56. ('nan', 'mean')]
  57. answer = answer or pd.DataFrame([
  58. {'id1' : 1, 'id2' : 'a', 'int_min' : 1, 'int_std' : pd.Series([1,2,3]).std(), 'float_mean' : np.mean([0.1, 7.0]), 'float_sum' : 7.1, 'nan_mean' : np.nan},
  59. {'id1' : 3, 'id2' : 'b', 'int_min' : 2, 'int_std' : pd.Series([2,3]).std(), 'float_mean' : np.mean([0.1, 99.9]), 'float_sum' : 100, 'nan_mean' : np.nan},
  60. {'id1' : 3, 'id2' : 'a', 'int_min' : 55, 'int_std' : np.nan, 'float_mean' : 99.9, 'float_sum' : 99.9, 'nan_mean' : np.nan},
  61. ]).sort_values(self.index_cols).set_index(self.index_cols)
  62. result = self.obj.get_kpis_by_aggregation(kpis = kpis).sort_values(self.index_cols).set_index(self.index_cols)
  63. assert_frame_equal(result, answer[result.columns])
  64. def test_dict_kpi(self, kpis = None, answer = None):
  65. '''
  66. '''
  67. kpis = kpis or {'int' : ['min', 'std'], 'float' : 'mean'}
  68. answer = answer or pd.DataFrame([
  69. {'id1' : 1, 'id2' : 'a', 'int_min' : 1, 'int_std' : pd.Series([1,2,3]).std(), 'float_mean' : np.mean([0.1, 7.0])},
  70. {'id1' : 3, 'id2' : 'b', 'int_min' : 2, 'int_std' : pd.Series([2,3]).std(), 'float_mean' : np.mean([0.1, 99.9])},
  71. {'id1' : 3, 'id2' : 'a', 'int_min' : 55, 'int_std' : np.nan, 'float_mean' : 99.9},
  72. ]).sort_values(self.index_cols).set_index(self.index_cols)
  73. result = self.obj.get_kpis_by_aggregation(kpis = kpis).sort_values(self.index_cols).set_index(self.index_cols)
  74. assert_frame_equal(result, answer[result.columns])
  75. def test_string_cols(self, kpis = None, answer = None):
  76. '''
  77. '''
  78. kpis = kpis or {'str' : ['sum']}
  79. answer = answer or pd.DataFrame([
  80. {'id1' : 1, 'id2' : 'a', 'str_sum' : 'anan'},
  81. {'id1' : 3, 'id2' : 'b', 'str_sum' : 'ca'},
  82. {'id1' : 3, 'id2' : 'a', 'str_sum' : 'a'},
  83. ]).sort_values(self.index_cols).set_index(self.index_cols)
  84. result = self.obj.get_kpis_by_aggregation(kpis = kpis).sort_values(self.index_cols).set_index(self.index_cols)
  85. assert_frame_equal(result, answer[result.columns])
  86. def test_custom_aggfunc(self, kpis, answer = None):
  87. '''
  88. '''
  89. if kpis is None:
  90. def custom_sum(x):
  91. return np.sum(x)
  92. kpis = {'int' : custom_sum}
  93. answer = answer or pd.DataFrame([
  94. {'id1' : 1, 'id2' : 'a', 'int_custom_sum' : 6},
  95. {'id1' : 3, 'id2' : 'b', 'int_custom_sum' : 55},
  96. {'id1' : 3, 'id2' : 'a', 'int_custom_sum' : 5},
  97. ]).sort_values(self.index_cols).set_index(self.index_cols)
  98. result = self.obj.get_kpis_by_aggregation(kpis = kpis).sort_values(self.index_cols).set_index(self.index_cols)
  99. assert_frame_equal(result, answer[result.columns])
  100. def test_some_wrong_col(self, kpis = None, answer = None):
  101. '''
  102. '''
  103. kpis = kpis or {'bla' : 'sum', 'int' : 'sum'}
  104. answer = answer or pd.DataFrame([
  105. {'id1' : 1, 'id2' : 'a', 'int_sum' : 6},
  106. {'id1' : 3, 'id2' : 'a', 'int_sum' : 55},
  107. {'id1' : 3, 'id2' : 'b', 'int_sum' : 5},
  108. ]).sort_values(self.index_cols).set_index(self.index_cols)
  109. result = self.obj.get_kpis_by_aggregation(kpis = kpis).sort_values(self.index_cols).set_index(self.index_cols)
  110. assert_frame_equal(result, answer[result.columns])
  111. def test_all_wrong_cols(self, kpis = None, answer = None):
  112. '''
  113. '''
  114. kpis = kpis or {'bla' : 'sum', 'blub' : 'sum'}
  115. result = self.obj.get_kpis_by_aggregation(kpis = kpis)
  116. answer = self.data[self.index_cols].drop_duplicates().reset_index(drop = True)
  117. assert_frame_equal(result, answer[result.columns])
  118. if __name__ == '__main__':
  119. path_to_log = os.path.join(os.environ.get('PROJECT_DIR'),
  120. 'tests', 'test_feature_engineering','test_in_memory_feature_engineering',
  121. 'test_kpis_by_aggregation.log')
  122. configure_logging(path_to_log)
  123. logger = logging.getLogger(__name__)
  124. inst = TestKpisByAggregation(path_to_log = path_to_log)
  125. inst.test_builtin_aggfuncs_numeric_cols()
  126. inst.test_dict_kpi()
  127. inst.test_string_cols()
  128. inst.test_some_wrong_col()
  129. inst.test_all_wrong_cols()
  130. logger.info('Done testing method get_kpis_by_aggregation!')