"""Hierarchical multi-label wrappers (standard methods to transform
hierarchical multi-label problems into single-label problems).

"""

from diles.learn import Learner, Label
from diles import stats
from diles.learn.util import sugconf


class _Tree(object):
    """Tree representing class taxonomy of labels.

    Creates binary labels and a corresponding learner for each node. A node
    is identified by its `path`, a tuple of the dot-separated components of a
    label name (e.g. ``("a", "b")`` for ``"a.b"``); the root has the empty
    path ``()``.

    """
    def __init__(self, cls, params, samples, labels, fwl, fwt, path=(),
                 seen=None):
        """The base learner class is given by `cls`. Corresponding learners
        are instantiated using `params`. The parameters `samples`, `labels`,
        `fwl` and `fwt` are those passed to the train method of the base
        learner instances (see :meth:`diles.learn.Learner.train`).

        If this tree is a subtree, `path` must be a tuple of node names,
        providing the path from the global tree root to this subtree.

        Keyword `seen` is used internally only and initialized by the root
        node to remember which label paths already have been traversed.

        """
        self.samples = samples
        self.fwl = fwl
        self.fwt = fwt
        self.path = path
        self.children = []  # subtrees for direct child paths, sorted by path

        level = len(path)  # depth of this node; the root has level 0

        # `seen` is shared by the whole tree (it is passed down to every
        # child), so each child path is turned into exactly one subtree.
        if seen is None:
            seen = set()
        seen.add(path)

        # One binary label per sample: True if any of the sample's label
        # names lies on or below this node's path. Direct child paths that
        # have not been visited yet are expanded recursively on the way.
        blabels = []
        for label in labels:
            found = False
            for name in label:
                parts = tuple(name.split("."))
                if parts[:level] != path:
                    continue
                found = True
                child_path = parts[:level + 1]  # only direct children
                # child_path == path when `name` is exactly this node
                if child_path != path and child_path not in seen:
                    self.children.append(_Tree(cls, params, samples, labels,
                                               fwl, fwt, child_path, seen))
            blabels.append(found)
        self.children.sort(key=lambda node: node.path)

        # create and train a learner for the current path
        self.learner = cls(**params)
        self.learner.train(samples, blabels, fwl, fwt)

    def predict(self, sample, debug=False):
        """Predict a label by hierarchically consulting base learners.

        Returns a pair ``(label, confidence)``:

        * If this node's learner rejects `sample`, the label is empty and the
          confidence is the one reported by that learner.
        * Otherwise all children are consulted and their labels are merged;
          the confidence is the mean of the children's confidences, or this
          node's own confidence if it has no children.
        * If no child contributes a label, this node's own dotted path is
          predicted. The root has no name of its own, so in that case it
          predicts an empty label instead of the empty string.

        """
        use, conf, _ = self.learner.predict(sample)
        if debug and self.path:
            print("%s: %s (%2.2f)" % (".".join(self.path), use, conf))
        if not use:
            return Label([]), conf

        predictions = set()
        confs = []
        for child in self.children:
            items, child_conf = child.predict(sample, debug)
            predictions |= items
            confs.append(child_conf)

        if not predictions and self.path:
            # no child applies: this node itself is the most specific label
            predictions = [".".join(self.path)]
        return Label(predictions), stats.mean(confs or [conf])

    def __str__(self):
        names = [".".join(self.path)] if self.path else []
        names.extend(str(child) for child in self.children)
        return ", ".join(names)

    def __repr__(self):
        return str(self)


class TreeWrapper(Learner):
    """Hierarchical multi-label learner which combines binary base learner.

    Base learners are organized in a tree representing the label hierarchies.
    Each node represents a hierarchy level and trains a binary base learner
    for the node's label.

    Implements the approach of "Binarized Structured Label Learning" by
    `Wu (2005)`_.

    .. _Wu (2005): http://dx.doi.org/10.1007/11527862_24

    """
    id = "HBR"

    paramspace = {}

    def __init__(self, cls, **params):
        """The base learner class is given by `cls`. Instances are created
        using `params`.

        """
        super(TreeWrapper, self).__init__()
        self.cls = cls
        self.params = params

    def train(self, samples, labels, fwl=None, fwt=0):
        """See :meth:`diles.learn.Learner.train`.

        Every label name ``x`` is rewritten to ``x.*`` before building the
        tree. The node for ``x`` is then positive for samples labeled ``x``
        or any label below ``x``, while its leaf child ``x.*`` is positive
        only for samples labeled exactly ``x``.

        """
        labels = [Label("%s.*" % e for e in l) for l in labels]
        self.tree = _Tree(self.cls, self.params, samples, labels, fwl, fwt)

    def predict(self, sample, debug=False): # pylint: disable-msg=W0221
        """See :meth:`diles.learn.Learner.predict`

        No ranking is returned (just ``None``).

        If the predicted label contains an inner tree node (a name not ending
        with ``.*``), the prediction is a guess and its confidence is passed
        through :func:`diles.learn.util.sugconf`. The ``.*`` suffixes added
        in :meth:`train` are stripped from the returned label.

        """
        label, conf = self.tree.predict(sample, debug)
        if debug:
            print(self.tree)
            # single-argument print: identical output on Python 2 and 3
            print("raw label: %s" % (label,))
        guess = label and not all(e.endswith(".*") for e in label)
        label = Label(e[:-2] if e.endswith(".*") else e for e in label)
        return label, sugconf(conf) if guess else conf, None

# =============================================================================
# tests
# =============================================================================

def __doctests():
    """
    >>> from diles.learn.learners import NBLearner, FMLearner
    >>> samples = "00", "01", "22", "02", "09"
    >>> labels = ["a.b", "x.y"], ["a.b", "x.z"], ["m.n.o.p"], ["m.n", "a"], ["a"]
    >>> labels = [Label(l) for l in labels]
    >>> l = TreeWrapper(NBLearner)
    >>> l.train(samples, labels)
    >>> l.predict("00", debug=True)
    a: True (0.94)
    a.*: False (0.60)
    a.b: True (0.72)
    a.b.*: True (0.72)
    m: False (0.83)
    x: True (0.72)
    x.y: True (0.50)
    x.y.*: True (0.50)
    x.z: False (0.78)
    a, a.*, a.b, a.b.*, m, m.n, m.n.*, m.n.o, m.n.o.p, m.n.o.p.*, x, x.y, x.y.*, x.z, x.z.*
    raw label: {'a.b.*', 'x.y.*'}
    ({'a.b', 'x.y'}, 0.71..., None)
    >>> l.predict("01", debug=True)
    a: True (0.94)
    a.*: False (0.60)
    a.b: True (0.72)
    a.b.*: True (0.72)
    m: False (0.83)
    x: True (0.72)
    x.y: False (0.78)
    x.z: True (0.50)
    x.z.*: True (0.50)
    a, a.*, a.b, a.b.*, m, m.n, m.n.*, m.n.o, m.n.o.p, m.n.o.p.*, x, x.y, x.y.*, x.z, x.z.*
    raw label: {'a.b.*', 'x.z.*'}
    ({'a.b', 'x.z'}, 0.71..., None)
    >>> l.predict("02", debug=True)
    a: True (0.83)
    a.*: True (0.16)
    a.b: False (0.76)
    m: True (0.61)
    m.n: True (0.61)
    m.n.*: False (0.33)
    m.n.o: False (0.83)
    x: False (0.76)
    a, a.*, a.b, a.b.*, m, m.n, m.n.*, m.n.o, m.n.o.p, m.n.o.p.*, x, x.y, x.y.*, x.z, x.z.*
    raw label: {'a.*', 'm.n'}
    ({'a', 'm.n'}, ?0.60..., None)
    >>> l.predict("03", debug=True)
    a: True (0.83)
    a.*: True (0.16)
    a.b: True (0.16)
    a.b.*: True (0.16)
    m: False (0.49)
    x: True (0.16)
    x.y: False (0.33)
    x.z: False (0.33)
    a, a.*, a.b, a.b.*, m, m.n, m.n.*, m.n.o, m.n.o.p, m.n.o.p.*, x, x.y, x.y.*, x.z, x.z.*
    raw label: {'a.*', 'a.b.*', 'x'}
    ({'a', 'a.b', 'x'}, ?0.32..., None)

    >>> from diles.tests import TSETS
    >>> samples, labels = TSETS['tricky-09-balanced']
    >>> l = TreeWrapper(NBLearner)
    >>> l.train(samples, labels)
    >>> for i in range(5):
    ...     print("%s %s" % (l.predict(samples[i], debug=False), labels[i]))
    ({'x', 'y', 'z'}, 0.986..., None) {'x', 'y', 'z'}
    ({'x'}, 0.982..., None) {'x'}
    ({'x', 'y', 'z'}, 0.899..., None) {'x', 'y', 'z'}
    ({'x', 'z'}, 0.771..., None) {'x', 'z'}
    ({'x', 'z'}, 0.795..., None) {'x', 'z'}

    """