# Natural Language Toolkit: Interface to Weka Classifiers
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""
Classifiers that make use of the external 'Weka' package.
"""
from __future__ import print_function

import time
import tempfile
import os
import subprocess
import re
import zipfile
from sys import stdin

from six import integer_types, string_types

from nltk.probability import DictionaryProbDist
from nltk.internals import java, config_java
from nltk.classify.api import ClassifierI

_weka_classpath = None
_weka_search = [
    '.',
    '/usr/share/weka',
    '/usr/local/share/weka',
    '/usr/lib/weka',
    '/usr/local/lib/weka',
]

def config_weka(classpath=None):
    global _weka_classpath

    # Make sure java's configured first.
    config_java()

    if classpath is not None:
        _weka_classpath = classpath

    if _weka_classpath is None:
        # Copy the search list so repeated calls don't keep prepending
        # $WEKAHOME to the module-level default.
        searchpath = list(_weka_search)
        if 'WEKAHOME' in os.environ:
            searchpath.insert(0, os.environ['WEKAHOME'])

        for path in searchpath:
            if os.path.exists(os.path.join(path, 'weka.jar')):
                _weka_classpath = os.path.join(path, 'weka.jar')
                version = _check_weka_version(_weka_classpath)
                if version:
                    print('[Found Weka: %s (version %s)]' % (_weka_classpath, version))
                else:
                    print('[Found Weka: %s]' % _weka_classpath)
                # Stop at the first jar found, so that earlier entries
                # (in particular $WEKAHOME) take priority.
                break

    if _weka_classpath is None:
        raise LookupError(
            'Unable to find weka.jar! Use config_weka() '
            'or set the WEKAHOME environment variable. '
            'For more information about Weka, please see '
            'http://www.cs.waikato.ac.nz/ml/weka/'
        )
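
# If weka.jar lives somewhere non-standard, callers can point NLTK at it
# directly rather than relying on the search path (illustrative path):
#
#     >>> config_weka('/opt/weka-3.8/weka.jar')
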
def _check_weka_version(jar):
    try:
        zf = zipfile.ZipFile(jar)
    except Exception:
        # Not a readable jar file; version unknown.
        return None
    try:
        try:
            # ZipFile.read() returns bytes; decode so the version string
            # prints cleanly under Python 3.
            return zf.read('weka/core/version.txt').decode('ascii').strip()
        except KeyError:
            return None
    finally:
        zf.close()

class WekaClassifier(ClassifierI):
    def __init__(self, formatter, model_filename):
        self._formatter = formatter
        self._model = model_filename

    def prob_classify_many(self, featuresets):
        return self._classify_many(featuresets, ['-p', '0', '-distribution'])

    def classify_many(self, featuresets):
        return self._classify_many(featuresets, ['-p', '0'])
    def _classify_many(self, featuresets, options):
        # Make sure we can find java & weka.
        config_weka()

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the test data file.
            test_filename = os.path.join(temp_dir, 'test.arff')
            self._formatter.write(test_filename, featuresets)

            # Call weka to classify the data.  The class name here appears
            # to serve only as an entry point: '-l' tells Weka to load the
            # serialized model, which determines the classifier actually used.
            cmd = [
                'weka.classifiers.bayes.NaiveBayes',
                '-l',
                self._model,
                '-T',
                test_filename,
            ] + options
            (stdout, stderr) = java(
                cmd,
                classpath=_weka_classpath,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            # Check if something went wrong.  stderr, like stdout, comes
            # back as bytes and needs decoding before we can search it.
            if stderr and not stdout:
                stderr = stderr.decode(stdin.encoding or 'utf-8')
                if 'Illegal options: -distribution' in stderr:
                    raise ValueError(
                        'The installed version of weka does '
                        'not support probability distribution '
                        'output.'
                    )
                else:
                    raise ValueError('Weka failed to generate output:\n%s' % stderr)

            # Parse weka's output.  stdin.encoding can be None (e.g. when
            # stdin is a pipe), so fall back to UTF-8.
            return self.parse_weka_output(
                stdout.decode(stdin.encoding or 'utf-8').split('\n')
            )
        finally:
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)
    def parse_weka_distribution(self, s):
        probs = [float(v) for v in re.split('[*,]+', s) if v.strip()]
        probs = dict(zip(self._formatter.labels(), probs))
        return DictionaryProbDist(probs)
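
    # Illustrative input for parse_weka_distribution(): with labels
    # ['male', 'female'], a distribution column such as '*0.75,0.25'
    # (Weka marks the predicted class with '*'; exact formatting varies
    # by Weka version) parses to
    # DictionaryProbDist({'male': 0.75, 'female': 0.25}).
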
    def parse_weka_output(self, lines):
        # Strip unwanted text from stdout.
        for i, line in enumerate(lines):
            if line.strip().startswith('inst#'):
                lines = lines[i:]
                break

        if lines[0].split() == ['inst#', 'actual', 'predicted', 'error', 'prediction']:
            return [line.split()[2].split(':')[1] for line in lines[1:] if line.strip()]
        elif lines[0].split() == [
            'inst#',
            'actual',
            'predicted',
            'error',
            'distribution',
        ]:
            return [
                self.parse_weka_distribution(line.split()[-1])
                for line in lines[1:]
                if line.strip()
            ]

        # is this safe?
        elif re.match(r'^0 \w+ [01]\.[0-9]* \?\s*$', lines[0]):
            return [line.split()[1] for line in lines if line.strip()]

        else:
            for line in lines[:10]:
                print(line)
            raise ValueError(
                'Unhandled output format -- your version '
                'of weka may not be supported.\n'
                '  Header: %s' % lines[0]
            )
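
    # Illustrative '-p 0' output that the first branch above expects (a
    # sketch, not captured output; details vary across Weka versions):
    #
    #     inst#     actual  predicted error prediction
    #         1     1:male     1:male        1
    #         2   2:female     1:male    +   0.75
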
    # [xx] full list of classifiers (some may be abstract?):
    # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
    # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
    # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
    # LogisticBase, M5Base, MultilayerPerceptron,
    # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
    # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
    # PreConstructedLinearModel, Prism, RandomForest,
    # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
    # RuleNode, SimpleLinearRegression, SimpleLogistic,
    # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
    # VotedPerceptron, Winnow, ZeroR

    _CLASSIFIER_CLASS = {
        'naivebayes': 'weka.classifiers.bayes.NaiveBayes',
        'C4.5': 'weka.classifiers.trees.J48',
        'log_regression': 'weka.classifiers.functions.Logistic',
        'svm': 'weka.classifiers.functions.SMO',
        'kstar': 'weka.classifiers.lazy.KStar',
        'ripper': 'weka.classifiers.rules.JRip',
    }
    @classmethod
    def train(
        cls,
        model_filename,
        featuresets,
        classifier='naivebayes',
        options=[],
        quiet=True,
    ):
        # Make sure we can find java & weka.
        config_weka()

        # Build an ARFF formatter.
        formatter = ARFF_Formatter.from_train(featuresets)

        temp_dir = tempfile.mkdtemp()
        try:
            # Write the training data file.
            train_filename = os.path.join(temp_dir, 'train.arff')
            formatter.write(train_filename, featuresets)

            if classifier in cls._CLASSIFIER_CLASS:
                javaclass = cls._CLASSIFIER_CLASS[classifier]
            elif classifier in cls._CLASSIFIER_CLASS.values():
                javaclass = classifier
            else:
                raise ValueError('Unknown classifier %s' % classifier)

            # Train the weka model.
            cmd = [javaclass, '-d', model_filename, '-t', train_filename]
            cmd += list(options)
            if quiet:
                stdout = subprocess.PIPE
            else:
                stdout = None
            java(cmd, classpath=_weka_classpath, stdout=stdout)

            # Return the new classifier.
            return WekaClassifier(formatter, model_filename)
        finally:
            for f in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, f))
            os.rmdir(temp_dir)
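
    # Classifier-specific Weka flags can be forwarded through ``options``.
    # For example, J48's pruning confidence is set with '-C' (illustrative
    # call; 'train_toks' and the model path are placeholders):
    #
    #     >>> WekaClassifier.train('/tmp/names.model', train_toks, 'C4.5',
    #     ...                      options=['-C', '0.25'])
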
class ARFF_Formatter:
    """
    Converts featuresets and labeled featuresets to ARFF-formatted
    strings, appropriate for input into Weka.

    Features and classes can be specified manually in the constructor, or may
    be determined from data using ``from_train``.
    """

    def __init__(self, labels, features):
        """
        :param labels: A list of all class labels that can be generated.
        :param features: A list of feature specifications, where
            each feature specification is a tuple (fname, ftype);
            and ftype is an ARFF type string such as NUMERIC or
            STRING.
        """
        self._labels = labels
        self._features = features
    def format(self, tokens):
        """Returns a string representation of ARFF output for the given data."""
        return self.header_section() + self.data_section(tokens)

    def labels(self):
        """Returns the list of classes."""
        return list(self._labels)

    def write(self, outfile, tokens):
        """Writes ARFF data to a file for the given data."""
        close_fp = False
        if not hasattr(outfile, 'write'):
            outfile = open(outfile, 'w')
            close_fp = True
        try:
            outfile.write(self.format(tokens))
        finally:
            # Only close handles we opened ourselves; callers keep
            # ownership of file objects they pass in.
            if close_fp:
                outfile.close()
    @staticmethod
    def from_train(tokens):
        """
        Constructs an ARFF_Formatter instance with class labels and feature
        types determined from the given data. Handles boolean, numeric and
        string (note: not nominal) types.
        """
        # Find the set of all attested labels.
        labels = set(label for (tok, label) in tokens)

        # Determine the types of all features.
        features = {}
        for tok, label in tokens:
            for (fname, fval) in tok.items():
                if issubclass(type(fval), bool):
                    ftype = '{True, False}'
                elif issubclass(type(fval), (integer_types, float)):
                    ftype = 'NUMERIC'
                elif issubclass(type(fval), string_types):
                    ftype = 'STRING'
                elif fval is None:
                    continue  # can't tell the type.
                else:
                    raise ValueError('Unsupported value type %r' % fval)

                if features.get(fname, ftype) != ftype:
                    raise ValueError('Inconsistent type for %s' % fname)
                features[fname] = ftype

        features = sorted(features.items())
        return ARFF_Formatter(labels, features)
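
    # Illustrative type inference by from_train(): given a token such as
    # ({'length': 6, 'vowel': True, 'last': 'k'}, 'male'), the features
    # come out as [('last', 'STRING'), ('length', 'NUMERIC'),
    # ('vowel', '{True, False}')].
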
    def header_section(self):
        """Returns an ARFF header as a string."""
        # Header comment.
        s = (
            '% Weka ARFF file\n'
            + '% Generated automatically by NLTK\n'
            + '%% %s\n\n' % time.ctime()
        )

        # Relation name
        s += '@RELATION rel\n\n'

        # Input attribute specifications
        for fname, ftype in self._features:
            s += '@ATTRIBUTE %-30r %s\n' % (fname, ftype)

        # Label attribute specification
        s += '@ATTRIBUTE %-30r {%s}\n' % ('-label-', ','.join(self._labels))

        return s
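
    # For features [('length', 'NUMERIC'), ('last', 'STRING')] and labels
    # ['male', 'female'], header_section() yields roughly the following
    # (attribute names are repr-quoted and padded by the %-30r format):
    #
    #     @RELATION rel
    #
    #     @ATTRIBUTE 'length'                       NUMERIC
    #     @ATTRIBUTE 'last'                         STRING
    #     @ATTRIBUTE '-label-'                      {male,female}
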
    def data_section(self, tokens, labeled=None):
        """
        Returns the ARFF data section for the given data.

        :param tokens: a list of featuresets (dicts) or labelled featuresets
            which are tuples (featureset, label).
        :param labeled: Indicates whether the given tokens are labeled
            or not.  If None, then the tokens will be assumed to be
            labeled if the first token's value is a tuple or list.
        """
        # Check if the tokens are labeled or unlabeled.  If unlabeled,
        # then use None as the label placeholder.
        if labeled is None:
            labeled = tokens and isinstance(tokens[0], (tuple, list))
        if not labeled:
            tokens = [(tok, None) for tok in tokens]

        # Data section
        s = '\n@DATA\n'
        for (tok, label) in tokens:
            for fname, ftype in self._features:
                s += '%s,' % self._fmt_arff_val(tok.get(fname))
            s += '%s\n' % self._fmt_arff_val(label)
        return s
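
    # A labeled token such as ({'length': 6, 'last': 'k'}, 'male'), with
    # the two features above, becomes the data row (missing values would
    # print as '?'):
    #
    #     6,'k','male'
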
    def _fmt_arff_val(self, fval):
        if fval is None:
            return '?'
        elif isinstance(fval, (bool, integer_types)):
            return '%s' % fval
        else:
            # Floats and strings both use repr(); for strings this adds
            # the quoting that the ARFF data section needs.
            return '%r' % fval
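
    # Illustrative mappings: None -> '?', True -> 'True', 7 -> '7',
    # 0.5 -> '0.5', and 'abc' -> "'abc'"; strings keep their repr quoting
    # so embedded spaces or commas survive in the @DATA section.
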
if __name__ == '__main__':
    from nltk.classify.util import names_demo, binary_names_demo_features

    def make_classifier(featuresets):
        return WekaClassifier.train('/tmp/name.model', featuresets, 'C4.5')

    classifier = names_demo(make_classifier, binary_names_demo_features)