tadm.py

# Natural Language Toolkit: Interface to TADM Classifier
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Joseph Frazee <jfrazee@mail.utexas.edu>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
from __future__ import print_function, unicode_literals

import sys
import subprocess

from six import string_types

from nltk.internals import find_binary

try:
    import numpy
except ImportError:
    pass

_tadm_bin = None


def config_tadm(bin=None):
    """
    Configure NLTK's interface to the ``tadm`` binary, locating it via
    the given path, the ``TADM`` environment variable, or the standard
    binary search locations.
    """
    global _tadm_bin
    _tadm_bin = find_binary(
        'tadm', bin, env_vars=['TADM'], binary_names=['tadm'], url='http://tadm.sf.net'
    )
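

# Usage sketch (not part of the original module): if ``tadm`` is not on the
# search path, it can be configured explicitly before training.  The binary
# location below is hypothetical, and the module is assumed to be importable
# as ``nltk.classify.tadm``.
#
#     from nltk.classify import tadm
#     tadm.config_tadm(bin='/usr/local/bin/tadm')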


def write_tadm_file(train_toks, encoding, stream):
    """
    Generate an input file for ``tadm`` based on the given corpus of
    classified tokens.

    :type train_toks: list(tuple(dict, str))
    :param train_toks: Training data, represented as a list of
        pairs, the first member of which is a feature dictionary,
        and the second of which is a classification label.
    :type encoding: TadmEventMaxentFeatureEncoding
    :param encoding: A feature encoding, used to convert featuresets
        into feature vectors.
    :type stream: stream
    :param stream: The stream to which the ``tadm`` input file should be
        written.
    """
    # See the following for a file format description:
    #
    # http://sf.net/forum/forum.php?thread_id=1391502&forum_id=473054
    # http://sf.net/forum/forum.php?thread_id=1675097&forum_id=473054
    labels = encoding.labels()
    for featureset, label in train_toks:
        length_line = '%d\n' % len(labels)
        stream.write(length_line)
        for known_label in labels:
            v = encoding.encode(featureset, known_label)
            line = '%d %d %s\n' % (
                int(label == known_label),
                len(v),
                ' '.join('%d %d' % u for u in v),
            )
            stream.write(line)
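

# Illustration (not part of the original module): for a single training token
# in a two-label problem, the loop above writes one block like the following.
# The first line is the number of labels; each remaining line is
# ``<1 if this is the true label, else 0> <number of feature pairs>
# <feature-id value ...>``.  The feature ids and values shown are made up.
#
#     2
#     1 2 0 1 3 2
#     0 1 5 1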


def parse_tadm_weights(paramfile):
    """
    Given the stdout output generated by ``tadm`` when training a
    model, return a ``numpy`` array containing the corresponding weight
    vector.
    """
    weights = []
    for line in paramfile:
        weights.append(float(line.strip()))
    return numpy.array(weights, 'd')
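

# Usage sketch (not part of the original module): ``tadm`` is assumed to have
# written one floating-point weight per line to the (hypothetical) file
# ``model.params``.
#
#     with open('model.params') as paramfile:
#         weights = parse_tadm_weights(paramfile)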


def call_tadm(args):
    """
    Call the ``tadm`` binary with the given arguments.
    """
    if isinstance(args, string_types):
        raise TypeError('args should be a list of strings')
    if _tadm_bin is None:
        config_tadm()

    # Call tadm via a subprocess, capturing stderr so that a useful error
    # message can be reported if the command fails.
    cmd = [_tadm_bin] + args
    p = subprocess.Popen(
        cmd, stdout=sys.stdout, stderr=subprocess.PIPE, universal_newlines=True
    )
    (stdout, stderr) = p.communicate()

    # Check the return code.
    if p.returncode != 0:
        print()
        print(stderr)
        raise OSError('tadm command failed!')
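

# Usage sketch (not part of the original module): the command-line flags and
# file names below are illustrative only and are not checked against any
# particular ``tadm`` release.
#
#     call_tadm(['-events_in', 'train.events', '-params_out', 'model.params'])
#
# Note that ``args`` must be a list of strings; passing a single string
# raises ``TypeError``, as enforced above.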


def names_demo():
    from nltk.classify.util import names_demo
    from nltk.classify.maxent import TadmMaxentClassifier

    classifier = names_demo(TadmMaxentClassifier.train)


def encoding_demo():
    import sys
    from nltk.classify.maxent import TadmEventMaxentFeatureEncoding

    tokens = [
        ({'f0': 1, 'f1': 1, 'f3': 1}, 'A'),
        ({'f0': 1, 'f2': 1, 'f4': 1}, 'B'),
        ({'f0': 2, 'f2': 1, 'f3': 1, 'f4': 1}, 'A'),
    ]
    encoding = TadmEventMaxentFeatureEncoding.train(tokens)
    write_tadm_file(tokens, encoding, sys.stdout)
    print()
    for i in range(encoding.length()):
        print('%s --> %d' % (encoding.describe(i), i))
    print()


if __name__ == '__main__':
    encoding_demo()
    names_demo()