smoothing.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # -*- coding: utf-8 -*-
  2. # Natural Language Toolkit: Language Model Smoothing
  3. #
  4. # Copyright (C) 2001-2019 NLTK Project
  5. # Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. """Smoothing algorithms for language modeling.
  9. According to Chen & Goodman 1995 these should work with both Backoff and
  10. Interpolation.
  11. """
  12. from nltk.lm.api import Smoothing
  13. def _count_non_zero_vals(dictionary):
  14. return sum(1.0 for c in dictionary.values() if c > 0)
  15. class WittenBell(Smoothing):
  16. """Witten-Bell smoothing."""
  17. def __init__(self, vocabulary, counter, discount=0.1, **kwargs):
  18. super(WittenBell, self).__init__(vocabulary, counter, *kwargs)
  19. def alpha_gamma(self, word, context):
  20. gamma = self.gamma(context)
  21. return (1.0 - gamma) * self.alpha(word, context), gamma
  22. def unigram_score(self, word):
  23. return self.counts.unigrams.freq(word)
  24. def alpha(self, word, context):
  25. return self.counts[context].freq(word)
  26. def gamma(self, context):
  27. n_plus = _count_non_zero_vals(self.counts[context])
  28. return n_plus / (n_plus + self.counts[len(context) + 1].N())
  29. class KneserNey(Smoothing):
  30. """Kneser-Ney Smoothing."""
  31. def __init__(self, vocabulary, counter, discount=0.1, **kwargs):
  32. super(KneserNey, self).__init__(vocabulary, counter, *kwargs)
  33. self.discount = discount
  34. def unigram_score(self, word):
  35. return 1.0 / len(self.vocab)
  36. def alpha_gamma(self, word, context):
  37. prefix_counts = self.counts[context]
  38. return self.alpha(word, prefix_counts), self.gamma(prefix_counts)
  39. def alpha(self, word, prefix_counts):
  40. return max(prefix_counts[word] - self.discount, 0.0) / prefix_counts.N()
  41. def gamma(self, prefix_counts):
  42. return self.discount * _count_non_zero_vals(prefix_counts) / prefix_counts.N()