The attached file (ransac.py) implements the RANSAC algorithm. An example image:
To run the file, save it to your computer and start IPython:
ipython -wthread
Then import the module and run the test program, e.g. `import ransac` followed by `ransac.test()`.
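To use the module you need to create a model class with two methods: `fit(data)`, which returns the fitted model parameters, and `get_error(data, model)`, which returns one error value per data row.
An example of such a model is the class LinearLeastSquaresModel, as seen in the file source (below):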
1 import numpy
2 import scipy # use numpy if scipy unavailable
3 import scipy.linalg # use numpy if scipy unavailable
4
5 ## Copyright (c) 2004-2007, Andrew D. Straw. All rights reserved.
6
7 ## Redistribution and use in source and binary forms, with or without
8 ## modification, are permitted provided that the following conditions are
9 ## met:
10
11 ## * Redistributions of source code must retain the above copyright
12 ## notice, this list of conditions and the following disclaimer.
13
14 ## * Redistributions in binary form must reproduce the above
15 ## copyright notice, this list of conditions and the following
16 ## disclaimer in the documentation and/or other materials provided
17 ## with the distribution.
18
19 ## * Neither the name of the Andrew D. Straw nor the names of its
20 ## contributors may be used to endorse or promote products derived
21 ## from this software without specific prior written permission.
22
23 ## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 ## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 ## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 ## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27 ## OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28 ## SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29 ## LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30 ## DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31 ## THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32 ## (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 ## OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
def ransac(data,model,n,k,t,d,debug=False,return_all=False):
    """Fit model parameters to data using the RANSAC algorithm.

    This implementation was written from the pseudocode found at
    http://en.wikipedia.org/w/index.php?title=RANSAC&oldid=116358182

    Every one of the k iterations does the following:

      1. pick n random rows of data (the "maybe inliers") and fit a
         candidate model to them;
      2. evaluate the candidate on all remaining rows and keep those whose
         error is smaller than t (the "also inliers");
      3. if more than d rows were kept, refit the model to the maybe
         inliers plus the also inliers and score the refit by the mean
         error over exactly those rows;
      4. remember the refit if its score is the lowest seen so far.

    Given:
        data - array of observed data points, one point per row (it is
               indexed with index arrays and data.shape[0] is read)
        model - object providing fit(data), which returns model
                parameters, and get_error(data, params), which returns
                an array with one error value per row
        n - the number of rows drawn at random to fit a candidate model
        k - the number of iterations to run
        t - threshold; a row counts as fitting when its error is < t
        d - a candidate is only considered when more than d of the
            remaining rows fit it
        debug - if true, print error statistics for every iteration
        return_all - if true, also return the indices of the inliers
    Return:
        bestfit - the model parameters with the lowest mean error, or
        (bestfit, {'inliers': idxs}) when return_all is true, where idxs
        are the row indices of data the best model was refitted to
    Raise:
        ValueError - if no iteration found more than d fitting rows
    """
    iterations = 0
    bestfit = None
    besterr = numpy.inf
    best_inlier_idxs = None
    while iterations < k:
        maybe_idxs, test_idxs = random_partition(n, data.shape[0])
        maybeinliers = data[maybe_idxs, :]
        test_points = data[test_idxs]
        maybemodel = model.fit(maybeinliers)
        test_err = model.get_error(test_points, maybemodel)
        # select indices of rows with accepted points
        also_idxs = test_idxs[test_err < t]
        alsoinliers = data[also_idxs, :]
        if debug:
            # Each message is one pre-formatted string so the call prints
            # the same text under both Python 2 and Python 3.
            if len(test_err) > 0:
                # min()/max() raise on an empty array, which happens when
                # every row was drawn into the fitting sample.
                print('test_err.min() %s' % (test_err.min(),))
                print('test_err.max() %s' % (test_err.max(),))
                print('numpy.mean(test_err) %s' % (numpy.mean(test_err),))
            print('iteration %d:len(alsoinliers) = %d' % (
                iterations, len(alsoinliers)))
        if len(alsoinliers) > d:
            # enough support: refit on all supporting rows and score the fit
            betterdata = numpy.concatenate((maybeinliers, alsoinliers))
            bettermodel = model.fit(betterdata)
            better_errs = model.get_error(betterdata, bettermodel)
            thiserr = numpy.mean(better_errs)
            if thiserr < besterr:
                bestfit = bettermodel
                besterr = thiserr
                best_inlier_idxs = numpy.concatenate((maybe_idxs, also_idxs))
        iterations += 1
    if bestfit is None:
        raise ValueError("did not meet fit acceptance criteria")
    if return_all:
        return bestfit, {'inliers': best_inlier_idxs}
    else:
        return bestfit
111
def random_partition(n,n_data):
    """Randomly split the indices 0 .. n_data-1 into two index arrays.

    The indices are shuffled with numpy's global random state; the first n
    shuffled indices are returned as the first array and the remaining
    ones as the second array.
    """
    shuffled = numpy.arange(n_data)
    numpy.random.shuffle(shuffled)  # shuffles in place
    return shuffled[:n], shuffled[n:]
119
class LinearLeastSquaresModel:
    """Linear system solved using linear least squares.

    This class serves as an example that fulfills the model interface
    needed by the ransac() function: fit(data) returns model parameters
    and get_error(data, model) returns one error value per data row.

    input_columns and output_columns are iterables of column indices into
    the 2-D data arrays passed to fit() and get_error().
    """
    def __init__(self,input_columns,output_columns,debug=False):
        self.input_columns = input_columns
        self.output_columns = output_columns
        self.debug = debug

    def _split_columns(self, data):
        """Return (A, B): the input and output columns of data as 2-D arrays."""
        A = numpy.vstack([data[:,i] for i in self.input_columns]).T
        B = numpy.vstack([data[:,i] for i in self.output_columns]).T
        return A, B

    def fit(self, data):
        """Return the least-squares solution x of A x = B for the rows of data."""
        A, B = self._split_columns(data)
        # lstsq also returns residues, rank and singular values; only the
        # solution is needed here.
        x = scipy.linalg.lstsq(A, B)[0]
        return x

    def get_error( self, data, model):
        """Return the summed squared output error of model for each row of data."""
        A, B = self._split_columns(data)
        # numpy.dot replaces the deprecated scipy.dot alias
        B_fit = numpy.dot(A, model)
        err_per_point = numpy.sum((B-B_fit)**2,axis=1) # sum squared error per row
        return err_per_point
142
def test():
    """Demonstrate ransac() on synthetic linear data and plot the result.

    Builds 500 samples of a random one-input/one-output linear system, adds
    Gaussian noise, overwrites 100 of the samples with outliers, then fits
    the data both with plain linear least squares and with ransac() using
    a LinearLeastSquaresModel. Finally the data and the RANSAC, exact and
    plain least-squares lines are plotted with pylab.

    Uses numpy's global random state, so every run differs. Propagates the
    ValueError raised by ransac() if no acceptable model is found.
    """
    # local switches for the optional parts of the demonstration
    add_outliers = True
    show_plot = True
    highlight_ransac_inliers = True  # False: colour true inliers/outliers instead

    # generate perfect input data
    n_samples = 500
    n_inputs = 1
    n_outputs = 1
    A_exact = 20*numpy.random.random((n_samples,n_inputs) )
    perfect_fit = 60*numpy.random.normal(size=(n_inputs,n_outputs) ) # the model
    # numpy.dot replaces the deprecated scipy.dot alias
    B_exact = numpy.dot(A_exact,perfect_fit)
    assert B_exact.shape == (n_samples,n_outputs)

    # add a little gaussian noise (linear least squares alone should handle this well)
    A_noisy = A_exact + numpy.random.normal(size=A_exact.shape )
    B_noisy = B_exact + numpy.random.normal(size=B_exact.shape )

    # defaults so the alternate plot below also works without outliers
    outlier_idxs = numpy.arange(0)
    non_outlier_idxs = numpy.arange(n_samples)
    if add_outliers:
        # add some outliers
        n_outliers = 100
        all_idxs = numpy.arange( A_noisy.shape[0] )
        numpy.random.shuffle(all_idxs)
        outlier_idxs = all_idxs[:n_outliers]
        non_outlier_idxs = all_idxs[n_outliers:]
        A_noisy[outlier_idxs] = 20*numpy.random.random((n_outliers,n_inputs) )
        B_noisy[outlier_idxs] = 50*numpy.random.normal(size=(n_outliers,n_outputs) )

    # setup model
    all_data = numpy.hstack( (A_noisy,B_noisy) )
    # plain lists, because these are also used directly as array indices
    input_columns = list(range(n_inputs)) # the first columns of the array
    output_columns = [n_inputs+i for i in range(n_outputs)] # the last columns of the array
    debug = False
    model = LinearLeastSquaresModel(input_columns,output_columns,debug=debug)

    # ordinary least squares over all data, for comparison with RANSAC
    linear_fit = scipy.linalg.lstsq(all_data[:,input_columns],
                                    all_data[:,output_columns])[0]

    # run RANSAC algorithm
    ransac_fit, ransac_data = ransac(all_data,model,
                                     50, 1000, 7e3, 300, # misc. parameters
                                     debug=debug,return_all=True)
    if show_plot:
        import pylab

        sort_idxs = numpy.argsort(A_exact[:,0])
        A_col0_sorted = A_exact[sort_idxs] # maintain as rank-2 array

        if highlight_ransac_inliers:
            pylab.plot( A_noisy[:,0], B_noisy[:,0], 'k.', label='data' )
            pylab.plot( A_noisy[ransac_data['inliers'],0], B_noisy[ransac_data['inliers'],0], 'bx', label='RANSAC data' )
        else:
            pylab.plot( A_noisy[non_outlier_idxs,0], B_noisy[non_outlier_idxs,0], 'k.', label='noisy data' )
            pylab.plot( A_noisy[outlier_idxs,0], B_noisy[outlier_idxs,0], 'r.', label='outlier data' )
        pylab.plot( A_col0_sorted[:,0],
                    numpy.dot(A_col0_sorted,ransac_fit)[:,0],
                    label='RANSAC fit' )
        pylab.plot( A_col0_sorted[:,0],
                    numpy.dot(A_col0_sorted,perfect_fit)[:,0],
                    label='exact system' )
        pylab.plot( A_col0_sorted[:,0],
                    numpy.dot(A_col0_sorted,linear_fit)[:,0],
                    label='linear fit' )
        pylab.legend()
        pylab.show()
206
# Run the demonstration when this file is executed as a script.
if __name__ == '__main__':
    test()
209