# Natural Language Toolkit: Agreement Metrics
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Lauri Hallila <laurihallila@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
#

"""Counts Paice's performance statistics for evaluating stemming algorithms.

What is required:
 - A dictionary of words grouped by their real lemmas
 - A dictionary of words grouped by stems from a stemming algorithm

When these are given, Understemming Index (UI), Overstemming Index (OI),
Stemming Weight (SW) and Error-rate relative to truncation (ERRT) are counted.

References:
Chris D. Paice (1994). An evaluation method for stemming algorithms.
In Proceedings of SIGIR, 42--50.
"""
from math import sqrt
  19. def get_words_from_dictionary(lemmas):
  20. '''
  21. Get original set of words used for analysis.
  22. :param lemmas: A dictionary where keys are lemmas and values are sets
  23. or lists of words corresponding to that lemma.
  24. :type lemmas: dict(str): list(str)
  25. :return: Set of words that exist as values in the dictionary
  26. :rtype: set(str)
  27. '''
  28. words = set()
  29. for lemma in lemmas:
  30. words.update(set(lemmas[lemma]))
  31. return words
  32. def _truncate(words, cutlength):
  33. '''Group words by stems defined by truncating them at given length.
  34. :param words: Set of words used for analysis
  35. :param cutlength: Words are stemmed by cutting at this length.
  36. :type words: set(str) or list(str)
  37. :type cutlength: int
  38. :return: Dictionary where keys are stems and values are sets of words
  39. corresponding to that stem.
  40. :rtype: dict(str): set(str)
  41. '''
  42. stems = {}
  43. for word in words:
  44. stem = word[:cutlength]
  45. try:
  46. stems[stem].update([word])
  47. except KeyError:
  48. stems[stem] = set([word])
  49. return stems
  50. # Reference: http://en.wikipedia.org/wiki/Line-line_intersection
  51. def _count_intersection(l1, l2):
  52. '''Count intersection between two line segments defined by coordinate pairs.
  53. :param l1: Tuple of two coordinate pairs defining the first line segment
  54. :param l2: Tuple of two coordinate pairs defining the second line segment
  55. :type l1: tuple(float, float)
  56. :type l2: tuple(float, float)
  57. :return: Coordinates of the intersection
  58. :rtype: tuple(float, float)
  59. '''
  60. x1, y1 = l1[0]
  61. x2, y2 = l1[1]
  62. x3, y3 = l2[0]
  63. x4, y4 = l2[1]
  64. denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
  65. if denominator == 0.0: # lines are parallel
  66. if x1 == x2 == x3 == x4 == 0.0:
  67. # When lines are parallel, they must be on the y-axis.
  68. # We can ignore x-axis because we stop counting the
  69. # truncation line when we get there.
  70. # There are no other options as UI (x-axis) grows and
  71. # OI (y-axis) diminishes when we go along the truncation line.
  72. return (0.0, y4)
  73. x = (
  74. (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)
  75. ) / denominator
  76. y = (
  77. (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)
  78. ) / denominator
  79. return (x, y)
  80. def _get_derivative(coordinates):
  81. '''Get derivative of the line from (0,0) to given coordinates.
  82. :param coordinates: A coordinate pair
  83. :type coordinates: tuple(float, float)
  84. :return: Derivative; inf if x is zero
  85. :rtype: float
  86. '''
  87. try:
  88. return coordinates[1] / coordinates[0]
  89. except ZeroDivisionError:
  90. return float('inf')
  91. def _calculate_cut(lemmawords, stems):
  92. '''Count understemmed and overstemmed pairs for (lemma, stem) pair with common words.
  93. :param lemmawords: Set or list of words corresponding to certain lemma.
  94. :param stems: A dictionary where keys are stems and values are sets
  95. or lists of words corresponding to that stem.
  96. :type lemmawords: set(str) or list(str)
  97. :type stems: dict(str): set(str)
  98. :return: Amount of understemmed and overstemmed pairs contributed by words
  99. existing in both lemmawords and stems.
  100. :rtype: tuple(float, float)
  101. '''
  102. umt, wmt = 0.0, 0.0
  103. for stem in stems:
  104. cut = set(lemmawords) & set(stems[stem])
  105. if cut:
  106. cutcount = len(cut)
  107. stemcount = len(stems[stem])
  108. # Unachieved merge total
  109. umt += cutcount * (len(lemmawords) - cutcount)
  110. # Wrongly merged total
  111. wmt += cutcount * (stemcount - cutcount)
  112. return (umt, wmt)
  113. def _calculate(lemmas, stems):
  114. '''Calculate actual and maximum possible amounts of understemmed and overstemmed word pairs.
  115. :param lemmas: A dictionary where keys are lemmas and values are sets
  116. or lists of words corresponding to that lemma.
  117. :param stems: A dictionary where keys are stems and values are sets
  118. or lists of words corresponding to that stem.
  119. :type lemmas: dict(str): list(str)
  120. :type stems: dict(str): set(str)
  121. :return: Global unachieved merge total (gumt),
  122. global desired merge total (gdmt),
  123. global wrongly merged total (gwmt) and
  124. global desired non-merge total (gdnt).
  125. :rtype: tuple(float, float, float, float)
  126. '''
  127. n = sum(len(lemmas[word]) for word in lemmas)
  128. gdmt, gdnt, gumt, gwmt = (0.0, 0.0, 0.0, 0.0)
  129. for lemma in lemmas:
  130. lemmacount = len(lemmas[lemma])
  131. # Desired merge total
  132. gdmt += lemmacount * (lemmacount - 1)
  133. # Desired non-merge total
  134. gdnt += lemmacount * (n - lemmacount)
  135. # For each (lemma, stem) pair with common words, count how many
  136. # pairs are understemmed and overstemmed.
  137. umt, wmt = _calculate_cut(lemmas[lemma], stems)
  138. # Add to total undesired and wrongly-merged totals
  139. gumt += umt
  140. gwmt += wmt
  141. # Each object is counted twice, so divide by two
  142. return (gumt / 2, gdmt / 2, gwmt / 2, gdnt / 2)
  143. def _indexes(gumt, gdmt, gwmt, gdnt):
  144. '''Count Understemming Index (UI), Overstemming Index (OI) and Stemming Weight (SW).
  145. :param gumt, gdmt, gwmt, gdnt: Global unachieved merge total (gumt),
  146. global desired merge total (gdmt),
  147. global wrongly merged total (gwmt) and
  148. global desired non-merge total (gdnt).
  149. :type gumt, gdmt, gwmt, gdnt: float
  150. :return: Understemming Index (UI),
  151. Overstemming Index (OI) and
  152. Stemming Weight (SW).
  153. :rtype: tuple(float, float, float)
  154. '''
  155. # Calculate Understemming Index (UI),
  156. # Overstemming Index (OI) and Stemming Weight (SW)
  157. try:
  158. ui = gumt / gdmt
  159. except ZeroDivisionError:
  160. # If GDMT (max merge total) is 0, define UI as 0
  161. ui = 0.0
  162. try:
  163. oi = gwmt / gdnt
  164. except ZeroDivisionError:
  165. # IF GDNT (max non-merge total) is 0, define OI as 0
  166. oi = 0.0
  167. try:
  168. sw = oi / ui
  169. except ZeroDivisionError:
  170. if oi == 0.0:
  171. # OI and UI are 0, define SW as 'not a number'
  172. sw = float('nan')
  173. else:
  174. # UI is 0, define SW as infinity
  175. sw = float('inf')
  176. return (ui, oi, sw)
class Paice(object):
    '''Class for storing lemmas, stems and evaluation metrics.

    On construction, computes Understemming Index (UI), Overstemming Index
    (OI), Stemming Weight (SW) and Error-Rate Relative to Truncation (ERRT)
    for the given groupings, following Paice (1994).
    '''

    def __init__(self, lemmas, stems):
        '''
        :param lemmas: A dictionary where keys are lemmas and values are sets
            or lists of words corresponding to that lemma.
        :param stems: A dictionary where keys are stems and values are sets
            or lists of words corresponding to that stem.
        :type lemmas: dict(str): list(str)
        :type stems: dict(str): set(str)
        '''
        self.lemmas = lemmas
        self.stems = stems
        # Coordinate pairs of the truncation line; filled in by _errt()
        self.coords = []
        # Global merge totals; computed by update()
        self.gumt, self.gdmt, self.gwmt, self.gdnt = (None, None, None, None)
        # Understemming index, overstemming index and stemming weight
        self.ui, self.oi, self.sw = (None, None, None)
        self.errt = None
        self.update()

    def __str__(self):
        text = ['Global Unachieved Merge Total (GUMT): %s\n' % self.gumt]
        text.append('Global Desired Merge Total (GDMT): %s\n' % self.gdmt)
        text.append('Global Wrongly-Merged Total (GWMT): %s\n' % self.gwmt)
        text.append('Global Desired Non-merge Total (GDNT): %s\n' % self.gdnt)
        text.append('Understemming Index (GUMT / GDMT): %s\n' % self.ui)
        text.append('Overstemming Index (GWMT / GDNT): %s\n' % self.oi)
        text.append('Stemming Weight (OI / UI): %s\n' % self.sw)
        text.append('Error-Rate Relative to Truncation (ERRT): %s\r\n' % self.errt)
        coordinates = ' '.join(['(%s, %s)' % item for item in self.coords])
        text.append('Truncation line: %s' % coordinates)
        return ''.join(text)

    def _get_truncation_indexes(self, words, cutlength):
        '''Count (UI, OI) when stemming is done by truncating words at 'cutlength'.

        :param words: Words used for the analysis
        :param cutlength: Words are stemmed by cutting them at this length
        :type words: set(str) or list(str)
        :type cutlength: int
        :return: Understemming and overstemming indexes
        :rtype: tuple(float, float)
        '''
        truncated = _truncate(words, cutlength)
        gumt, gdmt, gwmt, gdnt = _calculate(self.lemmas, truncated)
        # Only UI and OI are needed here; SW is discarded
        ui, oi = _indexes(gumt, gdmt, gwmt, gdnt)[:2]
        return (ui, oi)

    def _get_truncation_coordinates(self, cutlength=0):
        '''Count (UI, OI) pairs for truncation points until we find the segment where (ui, oi) crosses the truncation line.

        :param cutlength: Optional parameter to start counting from (ui, oi)
            coordinates gotten by stemming at this length. Useful for speeding up
            the calculations when you know the approximate location of the
            intersection.
        :type cutlength: int
        :return: List of coordinate pairs that define the truncation line
        :rtype: list(tuple(float, float))
        '''
        words = get_words_from_dictionary(self.lemmas)
        maxlength = max(len(word) for word in words)
        # Truncate words from different points until (0, 0) - (ui, oi) segment crosses the truncation line
        coords = []
        while cutlength <= maxlength:
            # Get (UI, OI) pair of current truncation point
            pair = self._get_truncation_indexes(words, cutlength)
            # Store only new coordinates so we'll have an actual
            # line segment when counting the intersection point
            if pair not in coords:
                coords.append(pair)
            if pair == (0.0, 0.0):
                # Stop counting if truncation line goes through origo;
                # length from origo to truncation line is 0
                return coords
            if len(coords) >= 2 and pair[0] > 0.0:
                derivative1 = _get_derivative(coords[-2])
                derivative2 = _get_derivative(coords[-1])
                # Derivative of the truncation line is a decreasing value;
                # when it passes Stemming Weight, we've found the segment
                # of truncation line intersecting with (0, 0) - (ui, oi) segment
                if derivative1 >= self.sw >= derivative2:
                    return coords
            cutlength += 1
        return coords

    def _errt(self):
        '''Count Error-Rate Relative to Truncation (ERRT).

        :return: ERRT, length of the line from origo to (UI, OI) divided by
            the length of the line from origo to the point defined by the same
            line when extended until the truncation line.
        :rtype: float
        '''
        # Count (UI, OI) pairs for truncation points until we find the segment where (ui, oi) crosses the truncation line
        self.coords = self._get_truncation_coordinates()
        if (0.0, 0.0) in self.coords:
            # Truncation line goes through origo, so ERRT cannot be counted
            if (self.ui, self.oi) != (0.0, 0.0):
                return float('inf')
            else:
                return float('nan')
        if (self.ui, self.oi) == (0.0, 0.0):
            # (ui, oi) is origo; define errt as 0.0
            return 0.0
        # Count the intersection point
        # Note that (self.ui, self.oi) cannot be (0.0, 0.0) and self.coords has different coordinates
        # so we have actual line segments instead of a line segment and a point
        intersection = _count_intersection(
            ((0, 0), (self.ui, self.oi)), self.coords[-2:]
        )
        # Count OP (length of the line from origo to (ui, oi))
        op = sqrt(self.ui ** 2 + self.oi ** 2)
        # Count OT (length of the line from origo to truncation line that goes through (ui, oi))
        ot = sqrt(intersection[0] ** 2 + intersection[1] ** 2)
        # OP / OT tells how well the stemming algorithm works compared to just truncating words
        return op / ot

    def update(self):
        '''Update statistics after lemmas and stems have been set.'''
        self.gumt, self.gdmt, self.gwmt, self.gdnt = _calculate(self.lemmas, self.stems)
        self.ui, self.oi, self.sw = _indexes(self.gumt, self.gdmt, self.gwmt, self.gdnt)
        self.errt = self._errt()
  290. def demo():
  291. '''Demonstration of the module.'''
  292. # Some words with their real lemmas
  293. lemmas = {
  294. 'kneel': ['kneel', 'knelt'],
  295. 'range': ['range', 'ranged'],
  296. 'ring': ['ring', 'rang', 'rung'],
  297. }
  298. # Same words with stems from a stemming algorithm
  299. stems = {
  300. 'kneel': ['kneel'],
  301. 'knelt': ['knelt'],
  302. 'rang': ['rang', 'range', 'ranged'],
  303. 'ring': ['ring'],
  304. 'rung': ['rung'],
  305. }
  306. print('Words grouped by their lemmas:')
  307. for lemma in sorted(lemmas):
  308. print('%s => %s' % (lemma, ' '.join(lemmas[lemma])))
  309. print()
  310. print('Same words grouped by a stemming algorithm:')
  311. for stem in sorted(stems):
  312. print('%s => %s' % (stem, ' '.join(stems[stem])))
  313. print()
  314. p = Paice(lemmas, stems)
  315. print(p)
  316. print()
  317. # Let's "change" results from a stemming algorithm
  318. stems = {
  319. 'kneel': ['kneel'],
  320. 'knelt': ['knelt'],
  321. 'rang': ['rang'],
  322. 'range': ['range', 'ranged'],
  323. 'ring': ['ring'],
  324. 'rung': ['rung'],
  325. }
  326. print('Counting stats after changing stemming results:')
  327. for stem in sorted(stems):
  328. print('%s => %s' % (stem, ' '.join(stems[stem])))
  329. print()
  330. p.stems = stems
  331. p.update()
  332. print(p)
  333. if __name__ == '__main__':
  334. demo()