#! /usr/bin/env python
## -*- coding: utf-8 -*-
## (C) 2012: Hans Georg Schaathun <georg@schaathun.net>
"""
Support Vector Machines with Grid Search.
:Module: pysteg.ml.svm
:Date: $Date$
:Revision: $Revision$
:Author: © 2012: Hans Georg Schaathun <georg@schaathun.net>
:Copyright: a small part of the code is derived from code
by (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin.
See source code for details.
This provides a wrapper for the LibSvm, where a grid search is
performed as an integral part of training. Thus gamma and C do
not have to be specified at initialisation time; instead search
ranges are provided and the training method finds the optimal
values within the range.
The module also exports the auxiliary functions to do grid
search and cross-validation.
"""
import sys, traceback
from threading import Thread
import Queue as queue
import mlpy
import numpy as np
from .classifiers import Classifier
# Public API of this module: the grid-searching SVM wrapper plus the grid
# search and cross-validation helpers it is built on.
__all__ = [ "SVM", "gridsearch", "xvalidate" ]
class SVM(Classifier):
    """Support Vector Machine with integrated grid search.

    This is a wrapper for :class:`mlpy.LibSvm` with additional
    functionality: for a ``c_svc`` machine with a non-linear kernel,
    :meth:`learn` chooses C and gamma with :func:`gridsearch` before
    training the final model.

    :Parameters:
      k : integer
        number of folds for the cross-validation in the grid search
      crange : (begin, end, step)
        search range for log2(C)
      grange : (begin, end, step)
        search range for log2(gamma)
      nworker : integer
        number of worker threads used by the grid search
      gnuplot : object or None
        passed on to :func:`gridsearch`
      verbosity : integer
        passed on to :func:`gridsearch`
    All other keyword arguments (svm_type, kernel_type, degree, gamma,
    coef0, C, nu, eps, p, cache_size, shrinking, probability, weight)
    are passed to the :class:`mlpy.LibSvm` constructor, both for the final
    model and for every model trained during the grid search.
    """

    def __init__(self, k, crange=(-5, 15, 2), grange=(3, -15, -2),
                 nworker=1, gnuplot=None, verbosity=0, **kw):
        self.k = k
        self.nworker = nworker
        self.gnuplot = gnuplot
        self.verbosity = verbosity
        self.crange = crange
        self.grange = grange
        # The trained mlpy.LibSvm; None until learn() or load_model().
        self.libsvm = None
        # Keyword arguments for mlpy.LibSvm; learn() stores the gamma and
        # C chosen by the grid search in here.
        self.param = kw

    def learn(self, x, y):
        """Train the SVM on feature vectors `x` with labels `y`.

        For a ``c_svc`` machine with a non-linear kernel, gamma and C are
        first chosen by :func:`gridsearch` (k-fold cross-validation over
        crange/grange) and stored in ``self.param``; otherwise the LibSvm
        parameters are used as given.  Returns whatever
        ``mlpy.LibSvm.learn`` returns.
        """
        # Missing keys fall back to what is assumed to be mlpy.LibSvm's own
        # defaults.  NOTE(review): c_svc / linear assumed as mlpy's
        # defaults -- confirm against the installed mlpy version.
        svm_type = self.param.get("svm_type", "c_svc")
        kernel_type = self.param.get("kernel_type", "linear")
        if svm_type == "c_svc" and kernel_type != "linear":
            # The grid search must evaluate the same kind of model as the
            # one trained below, so pass the LibSvm parameters on -- minus
            # gamma and C, which the search supplies itself (passing them
            # twice would be a TypeError in the workers).
            search_param = dict((key, val) for key, val in self.param.items()
                                if key not in ("gamma", "C"))
            (_, _, (best_c, best_g), _) = gridsearch(
                x, y, self.k, self.crange, self.grange,
                self.nworker, self.gnuplot, self.verbosity, **search_param)
            self.param["gamma"] = best_g
            self.param["C"] = best_c
        self.libsvm = mlpy.LibSvm(**self.param)
        return self.libsvm.learn(x, y)

    # The methods below delegate to the trained mlpy.LibSvm instance; they
    # fail while self.libsvm is still None (before learn()/load_model()).

    def pred(self, *a, **kw):
        return self.libsvm.pred(*a, **kw)

    def pred_probability(self, *a, **kw):
        return self.libsvm.pred_probability(*a, **kw)

    def pred_values(self, *a, **kw):
        return self.libsvm.pred_values(*a, **kw)

    def labels(self, *a, **kw):
        return self.libsvm.labels(*a, **kw)

    def nclasses(self, *a, **kw):
        return self.libsvm.nclasses(*a, **kw)

    def nsv(self, *a, **kw):
        return self.libsvm.nsv(*a, **kw)

    def label_nsv(self, *a, **kw):
        return self.libsvm.label_nsv(*a, **kw)

    @classmethod
    def load_model(cls, *a, **kw):
        """Return a new instance wrapping a model read by
        ``mlpy.LibSvm.load_model(*a, **kw)``.

        The instance is created with ``k=None``: it can predict, but
        ``k`` must be set before it can be re-trained with grid search.
        """
        obj = cls(None)
        obj.libsvm = mlpy.LibSvm.load_model(*a, **kw)
        return obj

    def save_model(self, *a, **kw):
        """Save the trained model via ``mlpy.LibSvm.save_model``."""
        return self.libsvm.save_model(*a, **kw)
def range_f(begin, end, step):
    """Like range, but works on non-integers too.

    Return the list ``[begin, begin+step, begin+2*step, ...]`` of all
    values that do not pass `end`.  Unlike :func:`range`, `end` itself is
    included when it is reached exactly.  A negative `step` counts
    downwards; if `begin` is already past `end` the list is empty.

    :Raises: ValueError if `step` is zero, since the sequence would never
      reach `end`.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    while True:
        if step > 0 and begin > end:
            break
        if step < 0 and begin < end:
            break
        seq.append(begin)
        begin = begin + step
    return seq
def permute_sequence(seq):
    """Reorder `seq` for :func:`calculate_jobs`.

    The middle element (index ``len(seq) // 2``) comes first, followed by
    the recursively reordered lower and upper halves, interleaved element
    by element with the lower half leading.  A sequence of length 0 or 1
    is returned unchanged (the very same object).

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    size = len(seq)
    if size <= 1:
        return seq
    centre = size // 2
    lower = permute_sequence(seq[:centre])
    upper = permute_sequence(seq[centre + 1:])
    result = [seq[centre]]
    # Interleave the two halves; the longer one simply runs on alone.
    for pos in range(max(len(lower), len(upper))):
        if pos < len(lower):
            result.append(lower[pos])
        if pos < len(upper):
            result.append(upper[pos])
    return result
def calculate_jobs(crange, grange):
    """Return a list of lists containing all the (c, g) pairs to be
    checked during the grid search.

    `crange` and `grange` are ``(begin, end, step)`` triples, expanded
    with :func:`range_f` and reordered with :func:`permute_sequence`.
    Each inner list pairs one newly added c value with every g value
    added so far, or one newly added g value with every c value added so
    far; the two directions are advanced in proportion to the number of
    values in each range.  Inner lists may be empty (the first one always
    is).

    :Raises: ValueError if either range expands to no values at all.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """
    (c_begin, c_end, c_step) = crange
    (g_begin, g_end, g_step) = grange
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    if not c_seq or not g_seq:
        # The progress ratios below would otherwise divide by zero.
        raise ValueError("empty grid search range: crange=%r, grange=%r"
                         % (crange, grange))
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []
    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # increase C resolution: new c value against all g values so far
            jobs.append([(c_seq[i], g) for g in g_seq[:j]])
            i += 1
        else:
            # increase g resolution: new g value against all c values so far
            jobs.append([(c, g_seq[j]) for c in c_seq[:i]])
            j += 1
    return jobs
class WorkerStopToken:
    """Sentinel put on the job queue to tell the worker threads to stop.

    The class object itself, not an instance, is used as the marker and
    is recognised with an ``is`` comparison.
    """
class Worker(Thread):
    """One thread for the execution of the grid search.

    The search is parallelised with one Worker per parallel thread.  Each
    worker repeatedly takes a ``(log2 C, log2 gamma)`` job from
    `job_queue`, cross-validates :class:`mlpy.LibSvm` with those
    parameters via :func:`xvalidate`, and puts
    ``(name, log2 C, log2 gamma, rate)`` on `result_queue`.  It stops when
    it takes the :class:`WorkerStopToken` job, which it puts back so that
    the other workers see it as well.  If a cross-validation run raises,
    the traceback is printed, the job is put back on `job_queue` for
    another worker, and this worker stops.

    Extra keyword arguments are passed on to every :class:`mlpy.LibSvm`
    constructed during cross-validation.

    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    """

    def __init__(self, name, job_queue, result_queue, x, y, fold, **kw):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.x = x
        self.y = y
        self.fold = fold
        self.xargs = kw

    def run(self):
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so that the other workers stop too.
                self.job_queue.put((cexp, gexp))
                break
            try:
                rate = xvalidate(mlpy.LibSvm, self.fold, self.x, self.y,
                                 gamma=2.0 ** gexp, C=2.0 ** cexp,
                                 **self.xargs)
            except Exception:
                # Without this the job would be lost and the collector
                # would wait for its result forever.  Hand it back to the
                # remaining workers, report the error, and quit.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                sys.stderr.write("worker %s quit.\n" % self.name)
                break
            self.result_queue.put((self.name, cexp, gexp, rate))
# The gridsearch() function and its auxiliaries are based on code
# from libSVM and the following copyright notice applies thereto:
#
## Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
## All rights reserved.
##
## Redistribution and use in source and binary forms, with or without
## modification, are permitted provided that the following conditions
## are met:
##
## 1. Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
##
## 2. Redistributions in binary form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in the
## documentation and/or other materials provided with the distribution.
##
## 3. Neither name of copyright holders nor the names of its contributors
## may be used to endorse or promote products derived from this software
## without specific prior written permission.
##
##
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
## CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
## EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
## PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
## PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
## LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
## NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
## SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
def _next_result(result_queue, workers, poll=1.0):
    """Return the next result tuple for :func:`gridsearch`.

    Blocks until `result_queue` yields an item, but raises RuntimeError
    instead of waiting forever once none of the `workers` threads is alive
    and the queue is empty (e.g. every worker stopped on an error).
    """
    while True:
        try:
            return result_queue.get(True, poll)
        except queue.Empty:
            if not any(thread.is_alive() for thread in workers):
                # A worker may have delivered a result just before exiting.
                try:
                    return result_queue.get_nowait()
                except queue.Empty:
                    raise RuntimeError(
                        "grid search aborted: all worker threads have stopped")


def gridsearch(x, y, fold,
               crange=(-5, 15, 2),
               grange=(3, -15, -2),
               nr_local_worker=1,
               gnuplot=None, verbosity=0, **kw):
    """
    Perform a grid search for the SVM parameters C and gamma.

    Every (log2 C, log2 gamma) pair produced by :func:`calculate_jobs`
    from `crange` and `grange` is evaluated by `fold`-fold
    cross-validation of :class:`mlpy.LibSvm`, using `nr_local_worker`
    worker threads.

    :Parameters:
      x : 2-D array
        training set feature vectors
      y : 1-D array
        training set labels
      fold : integer
        number of folds for the cross-validation
      crange, grange : (begin, end, step)
        search ranges for log2 C and log2 gamma respectively
      nr_local_worker : integer
        number of worker threads; must be at least 1
      gnuplot : object or None
        if given, it should implement the interface of the GnuPlot object;
        its ``redraw(db, [best_c1, best_g1, best_rate])`` method is called
        after each batch of grid points to update a contour plot of
        parameter choices during the search.
      verbosity : integer
        if greater than 2, one line is printed per evaluated grid point.
    All other keyword arguments are passed to the :class:`mlpy.LibSvm`
    constructor in every cross-validation run.

    :Returns:
      ``(db, (best_c1, best_g1), (best_c, best_g), best_rate)``, where db
      is a list of ``(log2 c, log2 gamma, rate)`` for all tested
      parameters, best_c1 and best_g1 are the optimal parameters (log2 c
      and log2 gamma), best_c and best_g are ``2**best_c1`` and
      ``2**best_g1``, and best_rate is the cross-validation accuracy at
      the optimised parameters.

    :Raises:
      ValueError if nr_local_worker is less than 1.
      RuntimeError if all worker threads stop before every grid point has
      been evaluated.

    This code is derived from the grid search script distributed
    with the libSVM library.
    Copyright (c) 2000-2010 Chih-Chung Chang and Chih-Jen Lin
    Copyright (c) 2012 Hans Georg Schaathun
    """
    if nr_local_worker < 1:
        raise ValueError("nr_local_worker must be at least 1, got %r"
                         % (nr_local_worker,))
    # put jobs in queue
    jobs = calculate_jobs(crange, grange)
    job_queue = queue.Queue(0)
    result_queue = queue.Queue(0)
    for line in jobs:
        for j in line:
            job_queue.put(j)
    # From now on put() inserts at the front of the queue, so the stop
    # token (or a job a worker hands back) is taken before waiting jobs.
    job_queue._put = job_queue.queue.appendleft
    workers = []
    try:
        # fire local workers
        for i in range(nr_local_worker):
            thread = Worker('local', job_queue, result_queue, x, y,
                            fold=fold, **kw)
            thread.start()
            workers.append(thread)
        # gather results
        done_jobs = {}
        db = []
        best_rate = -1
        best_c1, best_g1 = None, None
        best_c, best_g = None, None
        for line in jobs:
            for (c, g) in line:
                while (c, g) not in done_jobs:
                    (worker, c1, g1, rate) = _next_result(result_queue,
                                                          workers)
                    done_jobs[(c1, g1)] = rate
                    # Ties are broken in favour of the smaller C.
                    if (rate > best_rate) or (rate == best_rate and
                                              g1 == best_g1 and c1 < best_c1):
                        best_rate = rate
                        best_c1, best_g1 = c1, g1
                        best_c = 2.0 ** c1
                        best_g = 2.0 ** g1
                    if verbosity > 2:
                        print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" %
                              (worker, c1, g1, rate,
                               best_c, best_g, best_rate))
                db.append((c, g, done_jobs[(c, g)]))
            if gnuplot is not None:
                gnuplot.redraw(db, [best_c1, best_g1, best_rate])
    finally:
        # Always release the workers, even if the search was aborted;
        # each worker puts the token back for the next one.
        job_queue.put((WorkerStopToken, None))
    return (db, (best_c1, best_g1), (best_c, best_g), best_rate)
def xvalidate(cls, k, x, y, seed=0, **kw):
    """Cross-validate a model

    :Parameters:
      cls : class of the classifier to validate
        it is instantiated as ``cls(**kw)`` for every fold and must
        provide ``learn(x, y)`` and ``pred(x)``
      k : integer
        to perform k-fold crossvalidation
      x : 2-D array
        training set feature vectors
      y : 1-D array
        training set labels (any sequence accepted by numpy.asarray)
      seed : integer
        random seed passed to mlpy.cv_kfold for the fold split
    All other keyword arguments are passed to the classifier constructor.

    :Returns:
      the number of correctly predicted test-fold samples divided by the
      number of samples.  Presumably mlpy.cv_kfold puts every sample in
      exactly one test fold, making this the cross-validation accuracy.
    """
    # np.float was only an alias of float64 and no longer exists in
    # recent NumPy releases.
    x = np.asarray(x, dtype=np.float64)
    # Allow plain lists: the fold indices below are index arrays.
    y = np.asarray(y)
    (N, _) = x.shape
    idx = mlpy.cv_kfold(N, k, seed=seed)
    R = 0.0
    for tr, ts in idx:
        C = cls(**kw)
        C.learn(x[tr, :], y[tr])
        P = C.pred(x[ts, :])
        R += np.sum(P == y[ts])
    return R / N