weka.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Natural Language Toolkit: Interface to Weka Classifiers
  2. #
  3. # Copyright (C) 2001-2019 NLTK Project
  4. # Author: Edward Loper <edloper@gmail.com>
  5. # URL: <http://nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. """
  8. Classifiers that make use of the external 'Weka' package.
  9. """
from __future__ import print_function

import numbers
import os
import re
import subprocess
import tempfile
import time
import zipfile
from sys import stdin

from six import integer_types, string_types

from nltk.classify.api import ClassifierI
from nltk.internals import config_java, java
from nltk.probability import DictionaryProbDist
  22. _weka_classpath = None
  23. _weka_search = [
  24. '.',
  25. '/usr/share/weka',
  26. '/usr/local/share/weka',
  27. '/usr/lib/weka',
  28. '/usr/local/lib/weka',
  29. ]
  30. def config_weka(classpath=None):
  31. global _weka_classpath
  32. # Make sure java's configured first.
  33. config_java()
  34. if classpath is not None:
  35. _weka_classpath = classpath
  36. if _weka_classpath is None:
  37. searchpath = _weka_search
  38. if 'WEKAHOME' in os.environ:
  39. searchpath.insert(0, os.environ['WEKAHOME'])
  40. for path in searchpath:
  41. if os.path.exists(os.path.join(path, 'weka.jar')):
  42. _weka_classpath = os.path.join(path, 'weka.jar')
  43. version = _check_weka_version(_weka_classpath)
  44. if version:
  45. print(
  46. ('[Found Weka: %s (version %s)]' % (_weka_classpath, version))
  47. )
  48. else:
  49. print('[Found Weka: %s]' % _weka_classpath)
  50. _check_weka_version(_weka_classpath)
  51. if _weka_classpath is None:
  52. raise LookupError(
  53. 'Unable to find weka.jar! Use config_weka() '
  54. 'or set the WEKAHOME environment variable. '
  55. 'For more information about Weka, please see '
  56. 'http://www.cs.waikato.ac.nz/ml/weka/'
  57. )
  58. def _check_weka_version(jar):
  59. try:
  60. zf = zipfile.ZipFile(jar)
  61. except (SystemExit, KeyboardInterrupt):
  62. raise
  63. except:
  64. return None
  65. try:
  66. try:
  67. return zf.read('weka/core/version.txt')
  68. except KeyError:
  69. return None
  70. finally:
  71. zf.close()
  72. class WekaClassifier(ClassifierI):
  73. def __init__(self, formatter, model_filename):
  74. self._formatter = formatter
  75. self._model = model_filename
  76. def prob_classify_many(self, featuresets):
  77. return self._classify_many(featuresets, ['-p', '0', '-distribution'])
  78. def classify_many(self, featuresets):
  79. return self._classify_many(featuresets, ['-p', '0'])
  80. def _classify_many(self, featuresets, options):
  81. # Make sure we can find java & weka.
  82. config_weka()
  83. temp_dir = tempfile.mkdtemp()
  84. try:
  85. # Write the test data file.
  86. test_filename = os.path.join(temp_dir, 'test.arff')
  87. self._formatter.write(test_filename, featuresets)
  88. # Call weka to classify the data.
  89. cmd = [
  90. 'weka.classifiers.bayes.NaiveBayes',
  91. '-l',
  92. self._model,
  93. '-T',
  94. test_filename,
  95. ] + options
  96. (stdout, stderr) = java(
  97. cmd,
  98. classpath=_weka_classpath,
  99. stdout=subprocess.PIPE,
  100. stderr=subprocess.PIPE,
  101. )
  102. # Check if something went wrong:
  103. if stderr and not stdout:
  104. if 'Illegal options: -distribution' in stderr:
  105. raise ValueError(
  106. 'The installed version of weka does '
  107. 'not support probability distribution '
  108. 'output.'
  109. )
  110. else:
  111. raise ValueError('Weka failed to generate output:\n%s' % stderr)
  112. # Parse weka's output.
  113. return self.parse_weka_output(stdout.decode(stdin.encoding).split('\n'))
  114. finally:
  115. for f in os.listdir(temp_dir):
  116. os.remove(os.path.join(temp_dir, f))
  117. os.rmdir(temp_dir)
  118. def parse_weka_distribution(self, s):
  119. probs = [float(v) for v in re.split('[*,]+', s) if v.strip()]
  120. probs = dict(zip(self._formatter.labels(), probs))
  121. return DictionaryProbDist(probs)
  122. def parse_weka_output(self, lines):
  123. # Strip unwanted text from stdout
  124. for i, line in enumerate(lines):
  125. if line.strip().startswith("inst#"):
  126. lines = lines[i:]
  127. break
  128. if lines[0].split() == ['inst#', 'actual', 'predicted', 'error', 'prediction']:
  129. return [line.split()[2].split(':')[1] for line in lines[1:] if line.strip()]
  130. elif lines[0].split() == [
  131. 'inst#',
  132. 'actual',
  133. 'predicted',
  134. 'error',
  135. 'distribution',
  136. ]:
  137. return [
  138. self.parse_weka_distribution(line.split()[-1])
  139. for line in lines[1:]
  140. if line.strip()
  141. ]
  142. # is this safe:?
  143. elif re.match(r'^0 \w+ [01]\.[0-9]* \?\s*$', lines[0]):
  144. return [line.split()[1] for line in lines if line.strip()]
  145. else:
  146. for line in lines[:10]:
  147. print(line)
  148. raise ValueError(
  149. 'Unhandled output format -- your version '
  150. 'of weka may not be supported.\n'
  151. ' Header: %s' % lines[0]
  152. )
  153. # [xx] full list of classifiers (some may be abstract?):
  154. # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
  155. # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
  156. # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
  157. # LogisticBase, M5Base, MultilayerPerceptron,
  158. # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
  159. # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
  160. # PreConstructedLinearModel, Prism, RandomForest,
  161. # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
  162. # RuleNode, SimpleLinearRegression, SimpleLogistic,
  163. # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
  164. # VotedPerceptron, Winnow, ZeroR
  165. _CLASSIFIER_CLASS = {
  166. 'naivebayes': 'weka.classifiers.bayes.NaiveBayes',
  167. 'C4.5': 'weka.classifiers.trees.J48',
  168. 'log_regression': 'weka.classifiers.functions.Logistic',
  169. 'svm': 'weka.classifiers.functions.SMO',
  170. 'kstar': 'weka.classifiers.lazy.KStar',
  171. 'ripper': 'weka.classifiers.rules.JRip',
  172. }
  173. @classmethod
  174. def train(
  175. cls,
  176. model_filename,
  177. featuresets,
  178. classifier='naivebayes',
  179. options=[],
  180. quiet=True,
  181. ):
  182. # Make sure we can find java & weka.
  183. config_weka()
  184. # Build an ARFF formatter.
  185. formatter = ARFF_Formatter.from_train(featuresets)
  186. temp_dir = tempfile.mkdtemp()
  187. try:
  188. # Write the training data file.
  189. train_filename = os.path.join(temp_dir, 'train.arff')
  190. formatter.write(train_filename, featuresets)
  191. if classifier in cls._CLASSIFIER_CLASS:
  192. javaclass = cls._CLASSIFIER_CLASS[classifier]
  193. elif classifier in cls._CLASSIFIER_CLASS.values():
  194. javaclass = classifier
  195. else:
  196. raise ValueError('Unknown classifier %s' % classifier)
  197. # Train the weka model.
  198. cmd = [javaclass, '-d', model_filename, '-t', train_filename]
  199. cmd += list(options)
  200. if quiet:
  201. stdout = subprocess.PIPE
  202. else:
  203. stdout = None
  204. java(cmd, classpath=_weka_classpath, stdout=stdout)
  205. # Return the new classifier.
  206. return WekaClassifier(formatter, model_filename)
  207. finally:
  208. for f in os.listdir(temp_dir):
  209. os.remove(os.path.join(temp_dir, f))
  210. os.rmdir(temp_dir)
  211. class ARFF_Formatter:
  212. """
  213. Converts featuresets and labeled featuresets to ARFF-formatted
  214. strings, appropriate for input into Weka.
  215. Features and classes can be specified manually in the constructor, or may
  216. be determined from data using ``from_train``.
  217. """
  218. def __init__(self, labels, features):
  219. """
  220. :param labels: A list of all class labels that can be generated.
  221. :param features: A list of feature specifications, where
  222. each feature specification is a tuple (fname, ftype);
  223. and ftype is an ARFF type string such as NUMERIC or
  224. STRING.
  225. """
  226. self._labels = labels
  227. self._features = features
  228. def format(self, tokens):
  229. """Returns a string representation of ARFF output for the given data."""
  230. return self.header_section() + self.data_section(tokens)
  231. def labels(self):
  232. """Returns the list of classes."""
  233. return list(self._labels)
  234. def write(self, outfile, tokens):
  235. """Writes ARFF data to a file for the given data."""
  236. if not hasattr(outfile, 'write'):
  237. outfile = open(outfile, 'w')
  238. outfile.write(self.format(tokens))
  239. outfile.close()
  240. @staticmethod
  241. def from_train(tokens):
  242. """
  243. Constructs an ARFF_Formatter instance with class labels and feature
  244. types determined from the given data. Handles boolean, numeric and
  245. string (note: not nominal) types.
  246. """
  247. # Find the set of all attested labels.
  248. labels = set(label for (tok, label) in tokens)
  249. # Determine the types of all features.
  250. features = {}
  251. for tok, label in tokens:
  252. for (fname, fval) in tok.items():
  253. if issubclass(type(fval), bool):
  254. ftype = '{True, False}'
  255. elif issubclass(type(fval), (integer_types, float, bool)):
  256. ftype = 'NUMERIC'
  257. elif issubclass(type(fval), string_types):
  258. ftype = 'STRING'
  259. elif fval is None:
  260. continue # can't tell the type.
  261. else:
  262. raise ValueError('Unsupported value type %r' % ftype)
  263. if features.get(fname, ftype) != ftype:
  264. raise ValueError('Inconsistent type for %s' % fname)
  265. features[fname] = ftype
  266. features = sorted(features.items())
  267. return ARFF_Formatter(labels, features)
  268. def header_section(self):
  269. """Returns an ARFF header as a string."""
  270. # Header comment.
  271. s = (
  272. '% Weka ARFF file\n'
  273. + '% Generated automatically by NLTK\n'
  274. + '%% %s\n\n' % time.ctime()
  275. )
  276. # Relation name
  277. s += '@RELATION rel\n\n'
  278. # Input attribute specifications
  279. for fname, ftype in self._features:
  280. s += '@ATTRIBUTE %-30r %s\n' % (fname, ftype)
  281. # Label attribute specification
  282. s += '@ATTRIBUTE %-30r {%s}\n' % ('-label-', ','.join(self._labels))
  283. return s
  284. def data_section(self, tokens, labeled=None):
  285. """
  286. Returns the ARFF data section for the given data.
  287. :param tokens: a list of featuresets (dicts) or labelled featuresets
  288. which are tuples (featureset, label).
  289. :param labeled: Indicates whether the given tokens are labeled
  290. or not. If None, then the tokens will be assumed to be
  291. labeled if the first token's value is a tuple or list.
  292. """
  293. # Check if the tokens are labeled or unlabeled. If unlabeled,
  294. # then use 'None'
  295. if labeled is None:
  296. labeled = tokens and isinstance(tokens[0], (tuple, list))
  297. if not labeled:
  298. tokens = [(tok, None) for tok in tokens]
  299. # Data section
  300. s = '\n@DATA\n'
  301. for (tok, label) in tokens:
  302. for fname, ftype in self._features:
  303. s += '%s,' % self._fmt_arff_val(tok.get(fname))
  304. s += '%s\n' % self._fmt_arff_val(label)
  305. return s
  306. def _fmt_arff_val(self, fval):
  307. if fval is None:
  308. return '?'
  309. elif isinstance(fval, (bool, integer_types)):
  310. return '%s' % fval
  311. elif isinstance(fval, float):
  312. return '%r' % fval
  313. else:
  314. return '%r' % fval
  315. if __name__ == '__main__':
  316. from nltk.classify.util import names_demo, binary_names_demo_features
  317. def make_classifier(featuresets):
  318. return WekaClassifier.train('/tmp/name.model', featuresets, 'C4.5')
  319. classifier = names_demo(make_classifier, binary_names_demo_features)