Source code for pysteg.ml.svm

#! /usr/bin/env python
## -*- coding: utf-8 -*-
## (C) 2012: Hans Georg Schaathun <georg@schaathun.net> 

"""
Support Vector Machines with Grid Search.

:Module:    pysteg.ml.svm
:Date:      $Date$
:Revision:  $Revision$
:Author:    © 2012: Hans Georg Schaathun <georg@schaathun.net>
:Copyright: a small part of the code is derived from code
    by (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin.
    See source code for details.

This module provides a wrapper for LibSvm in which a grid search is
performed as an integral part of training.  Thus gamma and C do
not have to be specified at initialisation time; instead, search
ranges are provided and the training method finds the optimal
values within those ranges.

The module also exports the auxiliary functions to do grid
search and cross-validation.
"""

import sys, traceback
from threading import Thread
import Queue as queue
import mlpy
import numpy as np

from .classifiers import Classifier

__all__ = [ "SVM", "gridsearch", "xvalidate" ]

class SVM(Classifier):
    """Support Vector Machine.

    This is a wrapper for mlpy.LibSvm with additional functionality:
    when training a ``c_svc`` machine with a non-linear kernel, the
    parameters gamma and C are selected by a cross-validated grid
    search before the final model is trained.

    :Parameters:
      k : integer
        number of folds used for cross-validation in the grid search
      crange : tuple (begin, end, step)
        search range for log2(C)
      grange : tuple (begin, end, step)
        search range for log2(gamma)
      nworker : integer
        number of worker threads used by the grid search
      gnuplot : object or None
        optional plotting object passed on to :func:`gridsearch`
      verbosity : integer
        verbosity level passed on to :func:`gridsearch`

    All other keyword arguments are stored and passed to the
    mlpy.LibSvm constructor.  Note that :meth:`learn` reads the
    ``svm_type`` and ``kernel_type`` keys, so these two must be given
    explicitly (a KeyError is raised by :meth:`learn` otherwise).
    """

    def __init__(self, k, crange=(-5, 15, 2), grange=(3, -15, -2),
                 nworker=1, gnuplot=None, verbosity=0, **kw):
        # Keyword arguments forwarded to mlpy.LibSvm, e.g.
        #   svm_type='c_svc', kernel_type='rbf', degree=3,
        #   gamma=0.001, coef0=0, C=1, nu=0.5, eps=0.001,
        #   p=0.1, cache_size=100, shrinking=True, probability=False,
        #   weight={})
        self.k = k
        self.nworker = nworker
        self.gnuplot = gnuplot
        self.verbosity = verbosity
        self.crange = crange
        self.grange = grange
        self.libsvm = None      # created by learn() or load_model()
        self.param = kw

    def learn(self, x, y):
        """Train the classifier on feature vectors x with labels y.

        For a ``c_svc`` machine with a non-linear kernel, a grid search
        is run first and the best gamma and C found are stored in
        ``self.param`` (overriding any values given at construction).
        Returns whatever mlpy.LibSvm.learn returns.
        """
        if (self.param["svm_type"] == "c_svc"
                and self.param["kernel_type"] != "linear"):
            # The search must evaluate the same kind of machine that is
            # trained below, so forward every LibSvm parameter except the
            # two being searched (which xvalidate supplies itself).
            search_args = {key: value for key, value in self.param.items()
                           if key not in ("gamma", "C")}
            (db, (best_c1, best_g1), (best_c, best_g), best_rate) = \
                gridsearch(x, y, self.k, self.crange, self.grange,
                           self.nworker, self.gnuplot, self.verbosity,
                           **search_args)
            self.param["gamma"] = best_g
            self.param["C"] = best_c
        self.libsvm = mlpy.LibSvm(**self.param)
        return self.libsvm.learn(x, y)

    # The methods below delegate directly to the underlying mlpy.LibSvm
    # object; they require learn() or load_model() to have been called.

    def pred(self, *a, **kw):
        return self.libsvm.pred(*a, **kw)

    def pred_probability(self, *a, **kw):
        return self.libsvm.pred_probability(*a, **kw)

    def pred_values(self, *a, **kw):
        return self.libsvm.pred_values(*a, **kw)

    def labels(self, *a, **kw):
        return self.libsvm.labels(*a, **kw)

    def nclasses(self, *a, **kw):
        return self.libsvm.nclasses(*a, **kw)

    def nsv(self, *a, **kw):
        return self.libsvm.nsv(*a, **kw)

    def label_nsv(self, *a, **kw):
        return self.libsvm.label_nsv(*a, **kw)

    @classmethod
    def load_model(cls, *a, **kw):
        """Create a wrapper around a model loaded by mlpy.LibSvm.load_model.

        All arguments are passed to mlpy.LibSvm.load_model.  The returned
        object has no fold count (``k`` is None) and no stored LibSvm
        parameters, so it is suitable for prediction but must be
        reconfigured before :meth:`learn` can be called on it.
        """
        # __init__ requires k; no fold count is known for a loaded model.
        model = cls(None)
        model.libsvm = mlpy.LibSvm.load_model(*a, **kw)
        return model

    def save_model(self, *a, **kw):
        return self.libsvm.save_model(*a, **kw)
def range_f(begin, end, step):
    """like range, but works on non-integer too.

    Unlike the built-in range, the end point is included when it is
    hit exactly.  A zero step raises ValueError (it would otherwise
    never terminate).

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    while True:
        if step > 0 and begin > end:
            break
        if step < 0 and begin < end:
            break
        seq.append(begin)
        begin = begin + step
    return seq


def permute_sequence(seq):
    """Auxiliary function for :func:`calculate_jobs`.

    Reorder the sequence so that the middle element comes first,
    followed by the (recursively permuted) left and right halves
    interleaved.  Sequences of length 0 or 1 are returned as is.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    n = len(seq)
    if n <= 1:
        return seq

    mid = n // 2
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid + 1:])

    ret = [seq[mid]]
    # Alternate between the two halves until both are exhausted.
    while left or right:
        if left:
            ret.append(left.pop(0))
        if right:
            ret.append(right.pop(0))
    return ret


def calculate_jobs(crange, grange):
    """Return a list of lists, containing all the (c,g) pairs to
    be checked during the grid search.

    crange and grange are (begin, end, step) triplets as accepted by
    :func:`range_f`.  Each inner list is one batch of pairs; the first
    batch is always empty.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    (c_begin, c_end, c_step) = crange
    (g_begin, g_end, g_step) = grange
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    # Floats, so that the progress ratios below use true division.
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # increase C resolution
            line = [(c_seq[i], g_seq[k]) for k in range(0, j)]
            i = i + 1
        else:
            # increase g resolution
            line = [(c_seq[k], g_seq[j]) for k in range(0, i)]
            j = j + 1
        jobs.append(line)
    return jobs


class WorkerStopToken:
    "used to notify the worker to stop"
    pass


class Worker(Thread):
    """One thread for the execution of the grid search.
    The search is parallellised with one Worker per parallel thread.

    Each worker takes (log2 C, log2 gamma) pairs from job_queue,
    cross-validates mlpy.LibSvm with those parameters, and puts
    (name, cexp, gexp, rate) on result_queue.  Extra keyword
    arguments are passed on to :func:`xvalidate`.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """

    def __init__(self, name, job_queue, result_queue, x, y, fold, **kw):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.x = x
        self.y = y
        self.fold = fold
        self.xargs = kw

    def run(self):
        # NOTE(review): an exception raised by xvalidate terminates this
        # thread without posting a result for the job it had taken.
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so that the other workers see it too.
                self.job_queue.put((cexp, gexp))
                # print 'worker %s stop.' % self.name
                break
            rate = xvalidate(mlpy.LibSvm, self.fold, self.x, self.y,
                             gamma=2.0**gexp, C=2.0**cexp, **self.xargs)
            self.result_queue.put((self.name, cexp, gexp, rate))


# The gridsearch() function and its auxiliaries are based on code
# from libSVM and the following copyright notice applies thereto:
#
## Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
## All rights reserved.
##
## Redistribution and use in source and binary forms, with or without
## modification, are permitted provided that the following conditions
## are met:
##
## 1. Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
##
## 2. Redistributions in binary form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in the
## documentation and/or other materials provided with the distribution.
##
## 3. Neither name of copyright holders nor the names of its contributors
## may be used to endorse or promote products derived from this software
## without specific prior written permission.
##
##
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
## CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
## EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
## PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
## PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
## LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
## NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
## SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
def gridsearch(x, y, fold, crange=(-5, 15, 2), grange=(3, -15, -2),
               nr_local_worker=1, gnuplot=None, verbosity=0, **kw):
    """
    Perform a grid search based on the given test set and options.
    If gnuplot is given, it should implement the interface of GnuPlot
    object, and will be used to plot and update a contour plot of
    parameter choices during the search.

    :Parameters:
      x, y :
        training set feature vectors and labels, passed on to
        :func:`xvalidate`
      fold : integer
        number of folds for the cross-validation
      crange, grange : tuple (begin, end, step)
        search ranges for log2(C) and log2(gamma)
      nr_local_worker : integer
        number of worker threads; must be at least 1
      verbosity : integer
        if greater than 2, each result is printed as it arrives

    All other keyword arguments are passed on to the workers and
    thereby to :func:`xvalidate`.

    The return value is
    (db, (best_c1, best_g1), (best_c, best_g), best_rate),
    where db is a list of all the tested parameters with performance,
    best_c1 and best_g1 are the optimal parameters in the log domain
    (log2 C and log2 gamma), best_c and best_g are the corresponding
    values 2**best_c1 and 2**best_g1, and best_rate is the
    cross-validation accuracy at the optimised parameters.
    If the search ranges yield no jobs, the four best_* values are
    None and best_rate is -1.

    Raises ValueError if nr_local_worker is less than 1.

    This code is derived from the grid search script distributed
    with the libSVM library.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    Copyright (c) 2012 Hans Georg Schaathun
    """
    if nr_local_worker < 1:
        # Without workers the result loop below would block forever.
        raise ValueError("nr_local_worker must be at least 1")

    # put jobs in queue
    jobs = calculate_jobs(crange, grange)
    job_queue = queue.Queue(0)
    result_queue = queue.Queue(0)

    for line in jobs:
        for j in line:
            job_queue.put(j)

    # From now on, put() inserts at the front of the queue, so that the
    # stop token posted at the end is seen before any remaining job.
    job_queue._put = job_queue.queue.appendleft

    # fire local workers
    for i in range(nr_local_worker):
        Worker('local', job_queue, result_queue, x, y,
               fold=fold, **kw).start()

    # gather results
    done_jobs = {}
    db = []
    best_rate = -1
    best_c1, best_g1 = None, None
    best_c, best_g = None, None

    for line in jobs:
        for (c, g) in line:
            # Results arrive in arbitrary order; consume them until the
            # one for (c, g) is available, so db follows the job order.
            while (c, g) not in done_jobs:
                (worker, c1, g1, rate) = result_queue.get()
                done_jobs[(c1, g1)] = rate
                # On ties, prefer the smaller C at the same gamma.
                if (rate > best_rate) or \
                        (rate == best_rate and g1 == best_g1
                         and c1 < best_c1):
                    best_rate = rate
                    best_c1, best_g1 = c1, g1
                    best_c = 2.0**c1
                    best_g = 2.0**g1
                if verbosity > 2:
                    print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                          (worker, c1, g1, rate, best_c, best_g, best_rate))
            db.append((c, g, done_jobs[(c, g)]))
        if gnuplot is not None:
            gnuplot.redraw(db, [best_c1, best_g1, best_rate])

    job_queue.put((WorkerStopToken, None))
    return (db, (best_c1, best_g1), (best_c, best_g), best_rate)
def xvalidate(cls, k, x, y, seed=0, **kw):
    """Cross-validate a model

    :Parameters:
      cls : class
        of the classifier to validate
      k : integer
        to perform k-fold crossvalidation
      x : 2-D array
        training set feature vectors
      y : 1-D array
        training set labels
      seed : integer
        seed passed to mlpy.cv_kfold for the fold assignment

    All other keyword arguments are passed to the classifier constructor.

    :Returns:
      the number of correctly predicted test samples, summed over all
      folds, divided by the total number of samples.
    """
    # The builtin float is the dtype the former alias np.float stood for;
    # the alias itself no longer exists in current NumPy.
    x = np.asarray(x, dtype=float)
    # Labels are indexed with index arrays below, so plain sequences
    # have to be converted first.
    y = np.asarray(y)
    (N, D) = x.shape        # also enforces that x is two-dimensional
    idx = mlpy.cv_kfold(N, k, seed=seed)
    R = 0.0
    for tr, ts in idx:
        # A fresh classifier is trained for every fold.
        C = cls(**kw)
        C.learn(x[tr, :], y[tr])
        P = C.pred(x[ts, :])
        check = (P == y[ts])
        R += np.sum(check)
    return R / N