megam.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Natural Language Toolkit: Interface to Megam Classifier
  2. #
  3. # Copyright (C) 2001-2019 NLTK Project
  4. # Author: Edward Loper <edloper@gmail.com>
  5. # URL: <http://nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. """
  8. A set of functions used to interface with the external megam_ maxent
  9. optimization package. Before megam can be used, you should tell NLTK where it
  10. can find the megam binary, using the ``config_megam()`` function. Typical
  11. usage:
  12. >>> from nltk.classify import megam
  13. >>> megam.config_megam() # pass path to megam if not found in PATH # doctest: +SKIP
  14. [Found megam: ...]
  15. Use with MaxentClassifier, as in the example below; see the MaxentClassifier
  16. documentation for details.
  17. nltk.classify.MaxentClassifier.train(corpus, 'megam')
  18. .. _megam: http://www.umiacs.umd.edu/~hal/megam/index.html
  19. """
  20. from __future__ import print_function
  21. import subprocess
  22. from six import string_types
  23. from nltk import compat
  24. from nltk.internals import find_binary
  25. try:
  26. import numpy
  27. except ImportError:
  28. numpy = None
  29. ######################################################################
  30. # { Configuration
  31. ######################################################################
  32. _megam_bin = None
  33. def config_megam(bin=None):
  34. """
  35. Configure NLTK's interface to the ``megam`` maxent optimization
  36. package.
  37. :param bin: The full path to the ``megam`` binary. If not specified,
  38. then nltk will search the system for a ``megam`` binary; and if
  39. one is not found, it will raise a ``LookupError`` exception.
  40. :type bin: str
  41. """
  42. global _megam_bin
  43. _megam_bin = find_binary(
  44. 'megam',
  45. bin,
  46. env_vars=['MEGAM'],
  47. binary_names=['megam.opt', 'megam', 'megam_686', 'megam_i686.opt'],
  48. url='http://www.umiacs.umd.edu/~hal/megam/index.html',
  49. )
  50. ######################################################################
  51. # { Megam Interface Functions
  52. ######################################################################
  53. def write_megam_file(train_toks, encoding, stream, bernoulli=True, explicit=True):
  54. """
  55. Generate an input file for ``megam`` based on the given corpus of
  56. classified tokens.
  57. :type train_toks: list(tuple(dict, str))
  58. :param train_toks: Training data, represented as a list of
  59. pairs, the first member of which is a feature dictionary,
  60. and the second of which is a classification label.
  61. :type encoding: MaxentFeatureEncodingI
  62. :param encoding: A feature encoding, used to convert featuresets
  63. into feature vectors. May optionally implement a cost() method
  64. in order to assign different costs to different class predictions.
  65. :type stream: stream
  66. :param stream: The stream to which the megam input file should be
  67. written.
  68. :param bernoulli: If true, then use the 'bernoulli' format. I.e.,
  69. all joint features have binary values, and are listed iff they
  70. are true. Otherwise, list feature values explicitly. If
  71. ``bernoulli=False``, then you must call ``megam`` with the
  72. ``-fvals`` option.
  73. :param explicit: If true, then use the 'explicit' format. I.e.,
  74. list the features that would fire for any of the possible
  75. labels, for each token. If ``explicit=True``, then you must
  76. call ``megam`` with the ``-explicit`` option.
  77. """
  78. # Look up the set of labels.
  79. labels = encoding.labels()
  80. labelnum = dict((label, i) for (i, label) in enumerate(labels))
  81. # Write the file, which contains one line per instance.
  82. for featureset, label in train_toks:
  83. # First, the instance number (or, in the weighted multiclass case, the cost of each label).
  84. if hasattr(encoding, 'cost'):
  85. stream.write(
  86. ':'.join(str(encoding.cost(featureset, label, l)) for l in labels)
  87. )
  88. else:
  89. stream.write('%d' % labelnum[label])
  90. # For implicit file formats, just list the features that fire
  91. # for this instance's actual label.
  92. if not explicit:
  93. _write_megam_features(encoding.encode(featureset, label), stream, bernoulli)
  94. # For explicit formats, list the features that would fire for
  95. # any of the possible labels.
  96. else:
  97. for l in labels:
  98. stream.write(' #')
  99. _write_megam_features(encoding.encode(featureset, l), stream, bernoulli)
  100. # End of the instance.
  101. stream.write('\n')
  102. def parse_megam_weights(s, features_count, explicit=True):
  103. """
  104. Given the stdout output generated by ``megam`` when training a
  105. model, return a ``numpy`` array containing the corresponding weight
  106. vector. This function does not currently handle bias features.
  107. """
  108. if numpy is None:
  109. raise ValueError('This function requires that numpy be installed')
  110. assert explicit, 'non-explicit not supported yet'
  111. lines = s.strip().split('\n')
  112. weights = numpy.zeros(features_count, 'd')
  113. for line in lines:
  114. if line.strip():
  115. fid, weight = line.split()
  116. weights[int(fid)] = float(weight)
  117. return weights
  118. def _write_megam_features(vector, stream, bernoulli):
  119. if not vector:
  120. raise ValueError(
  121. 'MEGAM classifier requires the use of an ' 'always-on feature.'
  122. )
  123. for (fid, fval) in vector:
  124. if bernoulli:
  125. if fval == 1:
  126. stream.write(' %s' % fid)
  127. elif fval != 0:
  128. raise ValueError(
  129. 'If bernoulli=True, then all' 'features must be binary.'
  130. )
  131. else:
  132. stream.write(' %s %s' % (fid, fval))
  133. def call_megam(args):
  134. """
  135. Call the ``megam`` binary with the given arguments.
  136. """
  137. if isinstance(args, string_types):
  138. raise TypeError('args should be a list of strings')
  139. if _megam_bin is None:
  140. config_megam()
  141. # Call megam via a subprocess
  142. cmd = [_megam_bin] + args
  143. p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
  144. (stdout, stderr) = p.communicate()
  145. # Check the return code.
  146. if p.returncode != 0:
  147. print()
  148. print(stderr)
  149. raise OSError('megam command failed!')
  150. if isinstance(stdout, string_types):
  151. return stdout
  152. else:
  153. return stdout.decode('utf-8')