## Source code for pysteg.sql.errors

## -*- coding: utf-8 -*-
## (C) 2012: Hans Georg Schaathun <georg@schaathun.net> 

"""
Error profiling for steganalysers.  Very experimental and undocumented.
"""

from .coverselect import predict
import matplotlib.pyplot as plt
from sqlobject import SQLObject
from . import *
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

class ErrorProfiler(object):
    """
    Collect the scores that several classifiers assign to the same
    images, so that their errors can be compared pairwise.

    Attributes set by the constructor:

    models : the second element of every pair passed to the constructor,
             in the order given.
    plist  : a list of tuples (fn,label,s1,s2,...), one per image.
    """

    def __init__(self, L):
        """
        Given a list of (TestSet,SVModel) pairs L, build a list of tuples
        (fn,label,s1,s2,...) where fn is the file basename, label is the
        class label, and s1, s2, ... are the classification scores for
        each SVModel in the order they occur in L.  The list is stored
        as self.plist.

        The image names are taken from the first test set; every other
        test set must contain the same names (a missing name raises
        KeyError) with the same labels (a mismatch raises AssertionError).
        An empty L gives an empty profile.
        """
        self.models = [m for (_, m) in L]
        # One lookup table per pair: basename -> (label, score).
        tables = [
            dict((im.getBasename(), (im.label, im.getOneFeature(mod)))
                 for im in S)
            for (S, mod) in L
        ]
        names = tables[0].keys() if tables else []
        profile = []
        for n in names:
            entries = [t[n] for t in tables]
            label = entries[0][0]
            for (l, _) in entries:
                # Raised explicitly so the check survives `python -O`.
                if not (label == l):
                    raise AssertionError("Test sets are not comparable")
            profile.append((n, label) + tuple(f for (_, f) in entries))
        self.plist = profile

    def cat(self, k1, k2):
        """
        Split the image names into four lists according to which of the
        models k1 and k2 matched the label (see pairProfile).

        Returns the tuple (both,one,two,neither): names matched by both
        models, by k1 only, by k2 only, and by neither.
        """
        L = self.pairProfile(k1, k2)
        both = [a for (a, x, y) in L if (x, y) == (True, True)]
        one = [a for (a, x, y) in L if (x, y) == (True, False)]
        two = [a for (a, x, y) in L if (x, y) == (False, True)]
        neither = [a for (a, x, y) in L if (x, y) == (False, False)]
        return (both, one, two, neither)

    def pie(self, outfile, k1, k2):
        """
        Print the fraction of images in each of the four categories of
        cat() and save a pie chart of the counts to outfile.

        Raises ZeroDivisionError if the profile is empty.
        """
        L = [(x, y) for (a, x, y) in self.pairProfile(k1, k2)]
        N = len(L)
        both = L.count((True, True))
        one = L.count((True, False))
        two = L.count((False, True))
        neither = L.count((False, False))
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("%s %s %s %s" % (float(both) / N, float(one) / N,
                               float(two) / N, float(neither) / N))
        plt.pie([both, one, two, neither],
                labels=["both", k1, k2, "neither"])
        plt.savefig(outfile)

    def pairProfile(self, k1, k2):
        """
        Return a list of tuples (fn,ok1,ok2), one per profiled image,
        where ok1 (resp. ok2) tells whether predict() applied to the
        score from model k1 (resp. k2) equals the image label.

        k1 and k2 must be elements of self.models; list.index raises
        ValueError otherwise.
        """
        # Scores start at position 2 of each tuple, after (fn,label).
        idx1 = self.models.index(k1) + 2
        idx2 = self.models.index(k2) + 2
        # NOTE(review): predict comes from .coverselect; presumably it
        # maps a score to a predicted class label -- confirm.
        return [(q[0], q[1] == predict(q[idx1]), q[1] == predict(q[idx2]))
                for q in self.plist]
def getFeatures(imgset, L, f1, f2):
    """
    Generate tuples (n,v1,v2) for every name n in L, where v1 and v2
    are the values returned by getOneFeature for the features f1 and f2.

    imgset is either an SQLObject instance or a name which is looked up
    with ImageSet.byName.  Because this is a generator, nothing
    (including that lookup) happens before the first item is requested.
    """
    if not isinstance(imgset, SQLObject):
        imgset = ImageSet.byName(imgset)
    for n in L:
        # NOTE(review): presumably getBasename(n) on an image set looks
        # up the image with basename n -- confirm against ImageSet.
        im = imgset.getBasename(n)
        yield (n, im.getOneFeature(f1), im.getOneFeature(f2))
class Img(dict):
    """
    A dictionary of feature values for one image, keyed by feature key.

    Attributes copied from the image object given to the constructor:

    label : im.label
    name  : im.getBasename()
    image : the image object itself, used to load features on demand.
    """

    def __init__(self, im):
        """Wrap the image object im; no features are loaded yet."""
        self.label = im.label
        self.name = im.getBasename()
        self.image = im

    def loadCoverFeatures(self, L):
        """
        Store self.image.getCoverFeature(k) under key k for every key k
        in L.  L is a list of keys, or a single key given as a string.
        """
        if isinstance(L, str):
            L = [L]
        for k in L:
            self[k] = self.image.getCoverFeature(k)

    def loadFeatures(self, L):
        """
        Store self.image.getOneFeature(k) under key k for every key k
        in L.  L is a list of keys, or a single key given as a string.
        """
        if isinstance(L, str):
            L = [L]
        for k in L:
            self[k] = self.image.getOneFeature(k)

    def eType(self, key):
        """
        Return the error type of the prediction made from the value
        stored under key: "TP", "FP", "FN" or "TN".

        Raises KeyError if key has not been loaded, and TypeError if
        the label or the prediction is anything other than 0 or 1.
        """
        # NOTE(review): predict comes from .coverselect; presumably it
        # maps a score to a 0/1 class label -- confirm.
        predicted = predict(self[key])
        if self.label == 1 and predicted == 1:
            return "TP"
        elif self.label == 0 and predicted == 1:
            return "FP"
        elif self.label == 1 and predicted == 0:
            return "FN"
        elif self.label == 0 and predicted == 0:
            return "TN"
        else:
            raise TypeError("Only 0/1 labels are supported.")
class ImgList(list):
    """
    This class represents a list of images with feature values
    downloaded from the SQL server and managed in local memory.
    """

    def __init__(self, imgset=None):
        """
        Create the list with one Img per image in imgset.

        imgset is either an SQLObject instance or a name which is looked
        up with TestSet.byName.  With imgset=None the list starts empty.
        """
        if imgset is not None:
            if not isinstance(imgset, SQLObject):
                imgset = TestSet.byName(imgset)
            for im in imgset:
                self.append(Img(im))

    def loadCoverFeatures(self, L):
        """Call loadCoverFeatures(L) on every image in the list."""
        for im in self:
            im.loadCoverFeatures(L)

    def loadFeatures(self, L):
        """Call loadFeatures(L) on every image in the list."""
        for im in self:
            im.loadFeatures(L)

    def get(self, key, score=None, ecat=None):
        """
        Return the list of values stored under key, one per image.

        If ecat is given, only images for which eType(score) equals
        ecat are included; score is ignored otherwise.
        """
        if ecat is None:
            return [im[key] for im in self]
        return [im[key] for im in self if im.eType(score) == ecat]

    def histogram(self, key, bins=5):
        "Make a histogram of the feature values given by feature key."
        return np.histogram(self.get(key), bins=bins)

    def histogram2d(self, k1, k2, bins=5, score=None, ecat=None):
        """
        Make a 2D histogram of the feature values given by the two
        features k1 and k2.  score and ecat restrict the images counted
        in the same way as in get().  Returns the result of
        np.histogram2d.
        """
        return np.histogram2d(self.get(k1, score, ecat),
                              self.get(k2, score, ecat), bins=bins)

    def getBars(self, k1, k2, key=None, bins=5):
        """
        Return a dictionary with the 2D histogram of features k1 and k2
        under "total" and its bin edges under "x" and "y".

        If key is given, the dictionary also holds one histogram per
        error type ("FP", "TP", "FN", "TN") of the score key, computed
        on the same bin edges.  Progress messages are printed to stdout.
        """
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("[getBars] len = %s" % len(self))
        (h, xe, ye) = self.histogram2d(k1, k2, bins)
        R = {"total": h, "x": xe, "y": ye}
        if key is not None:
            for ecat in ("FP", "TP", "FN", "TN"):
                # Reuse the edges of the total histogram so that the
                # per-category counts line up with it bin by bin.
                R[ecat] = self.histogram2d(k1, k2, bins=(xe, ye),
                                           score=key, ecat=ecat)[0]
            print("%s %s" % (xe, ye))
        print("[getBars] return")
        return R

    def bar3d(self, outfile=None, **kw):
        """
        Draw the result of getBars(**kw) as a 3D bar chart and return
        the figure.  If outfile is given, the figure is also saved
        there.

        With per-category histograms available (key given), the bar
        height is (TP+TN)/total per bin, with empty bins drawn at zero;
        otherwise it is the total count per bin.
        """
        R = self.getBars(**kw)
        fig = plt.figure()
        ax = Axes3D(fig)
        # Recent matplotlib releases no longer attach the axes to the
        # figure in the Axes3D constructor; add it when it is missing.
        if ax not in fig.axes:
            fig.add_axes(ax)
        if "TP" in R:
            A = (R["TP"] + R["TN"]).astype(float)
            # Empty bins give 0/0; the NaNs are zeroed below, so the
            # warning numpy would emit here carries no information.
            with np.errstate(divide="ignore", invalid="ignore"):
                A /= R["total"]
        else:
            A = R["total"]
        (xe, ye) = R["x"], R["y"]
        (X, Y) = np.meshgrid(xe[:-1], ye[:-1])
        (dx, dy) = xe[1:] - xe[:-1], ye[1:] - ye[:-1]
        (DX, DY) = np.meshgrid(dx, dy)
        A[np.isnan(A)] = 0
        # histogram2d indexes its result [x,y] while meshgrid gives
        # [y,x]; transpose so that the flattened arrays correspond.
        A = A.transpose()
        ax.bar3d(
            X.flatten(), Y.flatten(), np.zeros_like(X.flatten()),
            0.75 * DX.flatten(), 0.75 * DY.flatten(), A.flatten(),
        )
        if outfile is not None:
            plt.savefig(outfile)
        return fig

    def erates(self, key):
        """
        Return the tuple (FP rate, FN rate, accuracy) for the
        predictions made from the score stored under key.

        Raises ZeroDivisionError if there are no images with label 0
        or no images with label 1 among the counted predictions.
        """
        # NOTE(review): predict comes from .coverselect; presumably it
        # maps a score to a 0/1 class label -- confirm.
        L = [(im.label, predict(im[key])) for im in self]
        TP = L.count((1, 1))
        FP = L.count((0, 1))
        TN = L.count((0, 0))
        FN = L.count((1, 0))
        return (float(FP) / (TN + FP),
                float(FN) / (TP + FN),
                float(TP + TN) / (TN + TP + FP + FN))
def _getBin(x,edges): r = 0 while x >= edges[r]: # print r, x, edges[r], len(edges) r += 1 if r == len(edges): break return r def _getBin2d(x,y,xedges,yedges): return (_getBin(x,xedges),_getBin(y,yedges))
def scatterPlot(outfile, imgset, L, f1, f2):
    """
    Make a scatter plot of feature f1 against feature f2 and save it
    to outfile.

    L is a list of groups, each a list of names as accepted by
    getFeatures(); every group is plotted in its own colour.  The
    colours are blue, red, yellow and green, reused in that order
    when there are more than four groups.
    """
    colours = ["b", "r", "y", "g"]
    scatteropt = {"s": 1, "linewidth": (0,)}
    for i, group in enumerate(L):
        Q = list(getFeatures(imgset, group, f1, f2))
        plt.scatter([x for (n, x, y) in Q],
                    [y for (n, x, y) in Q],
                    c=colours[i % len(colours)], **scatteropt)
    plt.savefig(outfile)