1 import numpy
2 import scipy
3 import scipy.linalg
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def ransac(data, model, n, k, t, d, debug=False, return_all=False):
    """Fit model parameters to data using the RANSAC algorithm.

    Written from the pseudocode at
    http://en.wikipedia.org/w/index.php?title=RANSAC&oldid=116358182

    Each of the ``k`` iterations:

    1. Draws ``n`` random rows and fits a candidate model to them.
    2. Collects every other row whose error under that candidate is below ``t``.
    3. If more than ``d`` such rows were found, refits on the sample plus those
       rows and scores the refit by its mean error.

    The lowest-scoring refit wins.

    Parameters
    ----------
    data : 2-D numpy array
        Observations, one per row.
    model : object
        Must provide ``fit(rows) -> params`` and
        ``get_error(rows, params) -> per-row error array``
        (see LinearLeastSquaresModel).
    n : int
        Number of rows randomly sampled to fit each candidate model.
    k : int
        Number of iterations (candidate models) to try.
    t : float
        Error threshold. A non-sampled row with error < t counts as an inlier.
    d : int
        A candidate is refit only if strictly more than ``d`` non-sampled rows
        are inliers.
    debug : bool
        Print per-iteration diagnostics.
    return_all : bool
        Also return a dict ``{'inliers': indices}`` with the row indices
        (sample + inliers) behind the best fit.

    Returns
    -------
    bestfit, or ``(bestfit, {'inliers': best_inlier_idxs})`` if ``return_all``.
        ``bestfit`` is the value returned by ``model.fit`` for the winning refit.

    Raises
    ------
    ValueError
        If ``n`` exceeds the number of rows, or if no candidate gathered more
        than ``d`` inliers.
    """
    n_data = data.shape[0]
    if n > n_data:
        # Every sample would consume all rows, leaving nothing to test;
        # fail fast instead of running k useless iterations.
        raise ValueError(
            "n (%s) exceeds the number of data points (%s)" % (n, n_data))

    bestfit = None
    besterr = numpy.inf
    best_inlier_idxs = None
    for iterations in range(k):
        maybe_idxs, test_idxs = random_partition(n, n_data)
        maybeinliers = data[maybe_idxs, :]
        test_points = data[test_idxs]
        maybemodel = model.fit(maybeinliers)
        test_err = model.get_error(test_points, maybemodel)
        # Rows outside the random sample that agree with the candidate model.
        also_idxs = test_idxs[test_err < t]
        alsoinliers = data[also_idxs, :]
        if debug:
            # min()/max() raise on an empty array (possible when n == n_data).
            if test_err.size:
                print('test_err.min()', test_err.min())
                print('test_err.max()', test_err.max())
                print('numpy.mean(test_err)', numpy.mean(test_err))
            print('iteration %d:len(alsoinliers) = %d' % (
                iterations, len(alsoinliers)))
        if len(alsoinliers) > d:
            # Enough support: refit on sample + inliers and score the refit.
            betterdata = numpy.concatenate((maybeinliers, alsoinliers))
            bettermodel = model.fit(betterdata)
            better_errs = model.get_error(betterdata, bettermodel)
            thiserr = numpy.mean(better_errs)
            if thiserr < besterr:
                bestfit = bettermodel
                besterr = thiserr
                best_inlier_idxs = numpy.concatenate((maybe_idxs, also_idxs))

    if bestfit is None:
        raise ValueError("did not meet fit acceptance criteria")
    if return_all:
        return bestfit, {'inliers': best_inlier_idxs}
    return bestfit
111
def random_partition(n, n_data):
    """Randomly split the row indices ``0 .. n_data-1`` into two groups.

    Returns a pair of index arrays: the first holds ``n`` randomly chosen
    indices (fewer if ``n > n_data``), the second holds all remaining indices.
    Together they form a permutation of ``range(n_data)``.
    """
    # permutation(int) shuffles a fresh arange, drawing from the global RNG.
    shuffled = numpy.random.permutation(n_data)
    return shuffled[:n], shuffled[n:]
119
class LinearLeastSquaresModel:
    """Linear system ``B = A x`` solved with linear least squares.

    Serves as an example of the model interface ransac() expects:
    ``fit(data)`` returns parameters, and ``get_error(data, model)`` returns
    one error value per row of ``data``.

    Parameters
    ----------
    input_columns : iterable of int
        Column indices of ``data`` that form A.
    output_columns : iterable of int
        Column indices of ``data`` that form B.
    debug : bool
        Stored on the instance; not used by the methods of this class.
    """

    def __init__(self, input_columns, output_columns, debug=False):
        self.input_columns = input_columns
        self.output_columns = output_columns
        self.debug = debug

    @staticmethod
    def _stack_columns(data, columns):
        """Return the given columns of 2-D ``data`` as an (n_rows, len(columns)) array."""
        return numpy.vstack([data[:, i] for i in columns]).T

    def fit(self, data):
        """Return the least-squares solution x, of shape (n_inputs, n_outputs), of A x = B."""
        A = self._stack_columns(data, self.input_columns)
        B = self._stack_columns(data, self.output_columns)
        # lstsq also returns residues, rank and singular values; only x is needed.
        x = scipy.linalg.lstsq(A, B)[0]
        return x

    def get_error(self, data, model):
        """Return the squared residual sum(|A x - B|^2) for each row of ``data``."""
        A = self._stack_columns(data, self.input_columns)
        B = self._stack_columns(data, self.output_columns)
        # numpy.dot, not scipy.dot: SciPy no longer re-exports numpy functions.
        B_fit = numpy.dot(A, model)
        err_per_point = numpy.sum((B - B_fit) ** 2, axis=1)
        return err_per_point
142
def test():
    """Demonstrate ransac() on a synthetic linear problem with gross outliers.

    Generates 500 noisy samples of ``B = A @ perfect_fit``, overwrites 100 of
    them with random outliers, and fits the data two ways: ordinary least
    squares and RANSAC. It then plots the data, the RANSAC inliers and both
    fits against the exact system (requires matplotlib's ``pylab``).
    """
    # Generate perfect input/output data.
    n_samples = 500
    n_inputs = 1
    n_outputs = 1
    A_exact = 20 * numpy.random.random((n_samples, n_inputs))
    perfect_fit = 60 * numpy.random.normal(size=(n_inputs, n_outputs))
    # numpy.dot, not scipy.dot: SciPy no longer re-exports numpy functions.
    B_exact = numpy.dot(A_exact, perfect_fit)
    assert B_exact.shape == (n_samples, n_outputs)

    # Add a little Gaussian noise (ordinary least squares copes with this).
    A_noisy = A_exact + numpy.random.normal(size=A_exact.shape)
    B_noisy = B_exact + numpy.random.normal(size=B_exact.shape)

    # Replace a random subset of rows with gross outliers
    # (this is what ordinary least squares does not cope with).
    n_outliers = 100
    all_idxs = numpy.arange(A_noisy.shape[0])
    numpy.random.shuffle(all_idxs)
    outlier_idxs = all_idxs[:n_outliers]
    A_noisy[outlier_idxs] = 20 * numpy.random.random((n_outliers, n_inputs))
    B_noisy[outlier_idxs] = 50 * numpy.random.normal(size=(n_outliers, n_outputs))

    # Pack inputs and outputs side by side; the model picks them out by column.
    all_data = numpy.hstack((A_noisy, B_noisy))
    input_columns = list(range(n_inputs))
    output_columns = [n_inputs + i for i in range(n_outputs)]
    debug = False
    model = LinearLeastSquaresModel(input_columns, output_columns, debug=debug)

    # Baseline: ordinary least squares on all rows, outliers included.
    linear_fit = scipy.linalg.lstsq(all_data[:, input_columns],
                                    all_data[:, output_columns])[0]

    # RANSAC: sample 50 rows, 1000 iterations, error threshold 7e3,
    # require more than 300 additional inliers to accept a candidate.
    ransac_fit, ransac_data = ransac(all_data, model,
                                     50, 1000, 7e3, 300,
                                     debug=debug, return_all=True)

    import pylab

    sort_idxs = numpy.argsort(A_exact[:, 0])
    A_col0_sorted = A_exact[sort_idxs]  # 2-D array, rows ordered by column 0

    inlier_idxs = ransac_data['inliers']
    pylab.plot(A_noisy[:, 0], B_noisy[:, 0], 'k.', label='data')
    pylab.plot(A_noisy[inlier_idxs, 0], B_noisy[inlier_idxs, 0], 'bx',
               label='RANSAC data')
    pylab.plot(A_col0_sorted[:, 0],
               numpy.dot(A_col0_sorted, ransac_fit)[:, 0],
               label='RANSAC fit')
    pylab.plot(A_col0_sorted[:, 0],
               numpy.dot(A_col0_sorted, perfect_fit)[:, 0],
               label='exact system')
    pylab.plot(A_col0_sorted[:, 0],
               numpy.dot(A_col0_sorted, linear_fit)[:, 0],
               label='linear fit')
    pylab.legend()
    pylab.show()
206
if __name__ == '__main__':
    # Run the demonstration only when executed as a script, never on import.
    test()
209