# -*- coding: utf-8 -*-
# Natural Language Toolkit: IBM Model 3
#
# Copyright (C) 2001-2013 NLTK Project
# Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

"""
Translation model that considers how a word can be aligned to
multiple words in another language.

IBM Model 3 improves on Model 2 by directly modeling the phenomenon
where a word in one language may be translated into zero or more words
in another. This is expressed by the fertility probability,
n(phi | source word).

If a source word translates into more than one word, it is possible to
generate sentences that have the same alignment in multiple ways. This
is modeled by a distortion step. The distortion probability, d(j|i,l,m),
predicts a target word position, given its aligned source word's
position. The distortion probability replaces the alignment probability
of Model 2.

The fertility probability is not applicable for NULL. Target words that
align to NULL are assumed to be distributed uniformly in the target
sentence. The existence of these words is modeled by p1, the probability
that a target word produced by a real source word requires another
target word that is produced by NULL.

The EM algorithm used in Model 3 is:

E step - In the training data, collect counts, weighted by prior
         probabilities.
         (a) count how many times a source language word is translated
             into a target language word
         (b) count how many times a particular position in the target
             sentence is aligned to a particular position in the source
             sentence
         (c) count how many times a source word is aligned to phi number
             of target words
         (d) count how many times NULL is aligned to a target word

M step - Estimate new probabilities based on the counts from the E step

Because there are too many possible alignments, only the most probable
ones are considered. First, the best alignment is determined using prior
probabilities. Then, a hill climbing approach is used to find other good
candidates.

Notations:

i: Position in the source sentence
   Valid values are 0 (for NULL), 1, 2, ..., length of source sentence
j: Position in the target sentence
   Valid values are 1, 2, ..., length of target sentence
l: Number of words in the source sentence, excluding NULL
m: Number of words in the target sentence
s: A word in the source language
t: A word in the target language
phi: Fertility, the number of target words produced by a source word
p1: Probability that a target word produced by a source word is
    accompanied by another target word that is aligned to NULL
p0: 1 - p1

References:

Philipp Koehn. 2010. Statistical Machine Translation.
Cambridge University Press, New York.

Peter E Brown, Stephen A. Della Pietra, Vincent J. Della Pietra, and
Robert L. Mercer. 1993. The Mathematics of Statistical Machine
Translation: Parameter Estimation. Computational Linguistics, 19 (2),
263-311.
"""
  63. from __future__ import division
  64. import warnings
  65. from collections import defaultdict
  66. from math import factorial
  67. from nltk.translate import AlignedSent
  68. from nltk.translate import Alignment
  69. from nltk.translate import IBMModel
  70. from nltk.translate import IBMModel2
  71. from nltk.translate.ibm_model import Counts
class IBMModel3(IBMModel):
    """
    Translation model that considers how a word can be aligned to
    multiple words in another language

    >>> bitext = []
    >>> bitext.append(AlignedSent(['klein', 'ist', 'das', 'haus'], ['the', 'house', 'is', 'small']))
    >>> bitext.append(AlignedSent(['das', 'haus', 'war', 'ja', 'groß'], ['the', 'house', 'was', 'big']))
    >>> bitext.append(AlignedSent(['das', 'buch', 'ist', 'ja', 'klein'], ['the', 'book', 'is', 'small']))
    >>> bitext.append(AlignedSent(['ein', 'haus', 'ist', 'klein'], ['a', 'house', 'is', 'small']))
    >>> bitext.append(AlignedSent(['das', 'haus'], ['the', 'house']))
    >>> bitext.append(AlignedSent(['das', 'buch'], ['the', 'book']))
    >>> bitext.append(AlignedSent(['ein', 'buch'], ['a', 'book']))
    >>> bitext.append(AlignedSent(['ich', 'fasse', 'das', 'buch', 'zusammen'], ['i', 'summarize', 'the', 'book']))
    >>> bitext.append(AlignedSent(['fasse', 'zusammen'], ['summarize']))

    >>> ibm3 = IBMModel3(bitext, 5)

    >>> print(round(ibm3.translation_table['buch']['book'], 3))
    1.0
    >>> print(round(ibm3.translation_table['das']['book'], 3))
    0.0
    >>> print(round(ibm3.translation_table['ja'][None], 3))
    1.0

    >>> print(round(ibm3.distortion_table[1][1][2][2], 3))
    1.0
    >>> print(round(ibm3.distortion_table[1][2][2][2], 3))
    0.0
    >>> print(round(ibm3.distortion_table[2][2][4][5], 3))
    0.75

    >>> print(round(ibm3.fertility_table[2]['summarize'], 3))
    1.0
    >>> print(round(ibm3.fertility_table[1]['book'], 3))
    1.0

    >>> print(ibm3.p1)
    0.054...

    >>> test_sentence = bitext[2]
    >>> test_sentence.words
    ['das', 'buch', 'ist', 'ja', 'klein']
    >>> test_sentence.mots
    ['the', 'book', 'is', 'small']
    >>> test_sentence.alignment
    Alignment([(0, 0), (1, 1), (2, 2), (3, None), (4, 3)])
    """

    def __init__(self, sentence_aligned_corpus, iterations, probability_tables=None):
        """
        Train on ``sentence_aligned_corpus`` and create a lexical
        translation model, a distortion model, a fertility model, and a
        model for generating NULL-aligned words.

        Translation direction is from ``AlignedSent.mots`` to
        ``AlignedSent.words``.

        :param sentence_aligned_corpus: Sentence-aligned parallel corpus
        :type sentence_aligned_corpus: list(AlignedSent)

        :param iterations: Number of iterations to run training algorithm
        :type iterations: int

        :param probability_tables: Optional. Use this to pass in custom
            probability values. If not specified, probabilities will be
            set to a uniform distribution, or some other sensible value.
            If specified, all the following entries must be present:
            ``translation_table``, ``alignment_table``,
            ``fertility_table``, ``p1``, ``distortion_table``.
            See ``IBMModel`` for the type and purpose of these tables.
        :type probability_tables: dict[str]: object
        """
        super(IBMModel3, self).__init__(sentence_aligned_corpus)
        self.reset_probabilities()

        if probability_tables is None:
            # Get translation and alignment probabilities from IBM Model 2.
            # Model 2's alignment table is used only for sampling alignments;
            # it is not retrained here (see train()).
            ibm2 = IBMModel2(sentence_aligned_corpus, iterations)
            self.translation_table = ibm2.translation_table
            self.alignment_table = ibm2.alignment_table
            self.set_uniform_probabilities(sentence_aligned_corpus)
        else:
            # Set user-defined probabilities
            self.translation_table = probability_tables['translation_table']
            self.alignment_table = probability_tables['alignment_table']
            self.fertility_table = probability_tables['fertility_table']
            self.p1 = probability_tables['p1']
            self.distortion_table = probability_tables['distortion_table']

        for n in range(0, iterations):
            self.train(sentence_aligned_corpus)

    def reset_probabilities(self):
        """
        Reset all probability tables to the base model's defaults, and
        initialize the distortion table so that unseen (j, i, l, m)
        combinations fall back to ``MIN_PROB``.
        """
        super(IBMModel3, self).reset_probabilities()

        self.distortion_table = defaultdict(
            lambda: defaultdict(
                lambda: defaultdict(lambda: defaultdict(lambda: self.MIN_PROB))
            )
        )
        """
        dict[int][int][int][int]: float. Probability(j | i,l,m).
        Values accessed as ``distortion_table[j][i][l][m]``.
        """

    def set_uniform_probabilities(self, sentence_aligned_corpus):
        """
        Set distortion probabilities uniformly to 1 / m for every
        sentence-length pair (l, m) observed in the corpus, and seed the
        fertility table and p1 with fixed starting values.
        """
        # d(j | i,l,m) = 1 / m for all i, j, l, m
        l_m_combinations = set()
        for aligned_sentence in sentence_aligned_corpus:
            l = len(aligned_sentence.mots)
            m = len(aligned_sentence.words)
            # Initialize each (l, m) pair only once
            if (l, m) not in l_m_combinations:
                l_m_combinations.add((l, m))
                initial_prob = 1 / m
                if initial_prob < IBMModel.MIN_PROB:
                    warnings.warn(
                        "A target sentence is too long ("
                        + str(m)
                        + " words). Results may be less accurate."
                    )
                for j in range(1, m + 1):
                    for i in range(0, l + 1):
                        self.distortion_table[j][i][l][m] = initial_prob

        # simple initialization, taken from GIZA++
        self.fertility_table[0] = defaultdict(lambda: 0.2)
        self.fertility_table[1] = defaultdict(lambda: 0.65)
        self.fertility_table[2] = defaultdict(lambda: 0.1)
        self.fertility_table[3] = defaultdict(lambda: 0.04)
        MAX_FERTILITY = 10
        initial_fert_prob = 0.01 / (MAX_FERTILITY - 4)
        for phi in range(4, MAX_FERTILITY):
            self.fertility_table[phi] = defaultdict(lambda: initial_fert_prob)

        self.p1 = 0.5

    def train(self, parallel_corpus):
        """
        Run one EM iteration: collect fractional counts over sampled
        alignments (E step), then re-estimate the probability tables
        by maximum likelihood (M step).
        """
        counts = Model3Counts()
        for aligned_sentence in parallel_corpus:
            l = len(aligned_sentence.mots)
            m = len(aligned_sentence.words)

            # Sample the alignment space
            sampled_alignments, best_alignment = self.sample(aligned_sentence)
            # Record the most probable alignment
            aligned_sentence.alignment = Alignment(
                best_alignment.zero_indexed_alignment()
            )

            # E step (a): Compute normalization factors to weigh counts
            total_count = self.prob_of_alignments(sampled_alignments)

            # E step (b): Collect counts
            for alignment_info in sampled_alignments:
                count = self.prob_t_a_given_s(alignment_info)
                normalized_count = count / total_count

                for j in range(1, m + 1):
                    counts.update_lexical_translation(
                        normalized_count, alignment_info, j
                    )
                    counts.update_distortion(normalized_count, alignment_info, j, l, m)

                counts.update_null_generation(normalized_count, alignment_info)
                counts.update_fertility(normalized_count, alignment_info)

        # M step: Update probabilities with maximum likelihood estimates
        # If any probability is less than MIN_PROB, clamp it to MIN_PROB
        existing_alignment_table = self.alignment_table
        self.reset_probabilities()
        self.alignment_table = existing_alignment_table  # don't retrain

        self.maximize_lexical_translation_probabilities(counts)
        self.maximize_distortion_probabilities(counts)
        self.maximize_fertility_probabilities(counts)
        self.maximize_null_generation_probabilities(counts)

    def maximize_distortion_probabilities(self, counts):
        """
        M step for distortion: d(j|i,l,m) = count(j,i,l,m) / count(i,l,m),
        clamped below at ``MIN_PROB``.
        """
        MIN_PROB = IBMModel.MIN_PROB
        for j, i_s in counts.distortion.items():
            for i, src_sentence_lengths in i_s.items():
                for l, trg_sentence_lengths in src_sentence_lengths.items():
                    for m in trg_sentence_lengths:
                        estimate = (
                            counts.distortion[j][i][l][m]
                            / counts.distortion_for_any_j[i][l][m]
                        )
                        self.distortion_table[j][i][l][m] = max(estimate, MIN_PROB)

    def prob_t_a_given_s(self, alignment_info):
        """
        Probability of target sentence and an alignment given the
        source sentence

        Multiplies the NULL-insertion, fertility, lexical translation,
        and distortion factors. Returns ``MIN_PROB`` as a floor as soon
        as the running product underflows it.
        """
        src_sentence = alignment_info.src_sentence
        trg_sentence = alignment_info.trg_sentence
        l = len(src_sentence) - 1  # exclude NULL
        m = len(trg_sentence) - 1
        p1 = self.p1
        p0 = 1 - p1

        probability = 1.0
        MIN_PROB = IBMModel.MIN_PROB

        # Combine NULL insertion probability
        null_fertility = alignment_info.fertility_of_i(0)
        probability *= pow(p1, null_fertility) * pow(p0, m - 2 * null_fertility)
        if probability < MIN_PROB:
            return MIN_PROB

        # Compute combination (m - null_fertility) choose null_fertility
        for i in range(1, null_fertility + 1):
            probability *= (m - null_fertility - i + 1) / i
            if probability < MIN_PROB:
                return MIN_PROB

        # Combine fertility probabilities
        for i in range(1, l + 1):
            fertility = alignment_info.fertility_of_i(i)
            probability *= (
                factorial(fertility) * self.fertility_table[fertility][src_sentence[i]]
            )
            if probability < MIN_PROB:
                return MIN_PROB

        # Combine lexical and distortion probabilities
        for j in range(1, m + 1):
            t = trg_sentence[j]
            i = alignment_info.alignment[j]
            s = src_sentence[i]

            probability *= (
                self.translation_table[t][s] * self.distortion_table[j][i][l][m]
            )
            if probability < MIN_PROB:
                return MIN_PROB

        return probability
  275. class Model3Counts(Counts):
  276. """
  277. Data object to store counts of various parameters during training.
  278. Includes counts for distortion.
  279. """
  280. def __init__(self):
  281. super(Model3Counts, self).__init__()
  282. self.distortion = defaultdict(
  283. lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))
  284. )
  285. self.distortion_for_any_j = defaultdict(
  286. lambda: defaultdict(lambda: defaultdict(lambda: 0.0))
  287. )
  288. def update_distortion(self, count, alignment_info, j, l, m):
  289. i = alignment_info.alignment[j]
  290. self.distortion[j][i][l][m] += count
  291. self.distortion_for_any_j[i][l][m] += count