This is an archival dump of old wiki content — see scipy.org for current material.
Please see http://scipy-cookbook.readthedocs.org/

Attachment 'segmentaxis.py'

Download

   1 import numpy as N
   2 import unittest
   3 from numpy.testing import NumpyTestCase, assert_array_almost_equal,             assert_almost_equal, assert_equal
   4 import warnings
   5 
   6 def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
   7     """Generate a new array that chops the given array along the given axis into overlapping frames.
   8 
   9     example:
  10     >>> segment_axis(arange(10), 4, 2)
  11     array([[0, 1, 2, 3],
  12            [2, 3, 4, 5],
  13            [4, 5, 6, 7],
  14            [6, 7, 8, 9]])
  15 
  16     arguments:
  17     a       The array to segment
  18     length  The length of each frame
  19     overlap The number of array elements by which the frames should overlap
  20     axis    The axis to operate on; if None, act on the flattened array
  21     end     What to do with the last frame, if the array is not evenly
  22             divisible into pieces. Options are:
  23 
  24             'cut'   Simply discard the extra values
  25             'wrap'  Copy values from the beginning of the array
  26             'pad'   Pad with a constant value
  27 
  28     endvalue    The value to use for end='pad'
  29 
  30     The array is not copied unless necessary (either because it is 
  31     unevenly strided and being flattened or because end is set to 
  32     'pad' or 'wrap').
  33     """
  34 
  35     if axis is None:
  36         a = N.ravel(a) # may copy
  37         axis = 0
  38 
  39     l = a.shape[axis]
  40 
  41     if overlap>=length:
  42         raise ValueError, "frames cannot overlap by more than 100%"
  43     if overlap<0 or length<=0:
  44         raise ValueError, "overlap must be nonnegative and length must be positive"
  45 
  46     if l<length or (l-length)%(length-overlap):
  47         if l>length:
  48             roundup = length + (1+(l-length)//(length-overlap))*(length-overlap)
  49             rounddown = length + ((l-length)//(length-overlap))*(length-overlap)
  50         else:
  51             roundup = length
  52             rounddown = 0
  53         assert rounddown<l<roundup
  54         assert roundup==rounddown+(length-overlap) or (roundup==length and rounddown==0)
  55         a = a.swapaxes(-1,axis)
  56 
  57         if end=='cut':
  58             a = a[...,:rounddown]
  59         elif end in ['pad','wrap']: # copying will be necessary
  60             s = list(a.shape)
  61             s[-1]=roundup
  62             b = N.empty(s,dtype=a.dtype)
  63             b[...,:l] = a
  64             if end=='pad':
  65                 b[...,l:] = endvalue
  66             elif end=='wrap':
  67                 b[...,l:] = a[...,:roundup-l]
  68             a = b
  69         
  70         a = a.swapaxes(-1,axis)
  71 
  72 
  73     l = a.shape[axis]
  74     if l==0:
  75         raise ValueError, "Not enough data points to segment array in 'cut' mode; try 'pad' or 'wrap'"
  76     assert l>=length
  77     assert (l-length)%(length-overlap) == 0
  78     n = 1+(l-length)//(length-overlap)
  79     s = a.strides[axis]
  80     newshape = a.shape[:axis]+(n,length)+a.shape[axis+1:]
  81     newstrides = a.strides[:axis]+((length-overlap)*s,s) + a.strides[axis+1:]
  82 
  83     try: 
  84         return N.ndarray.__new__(N.ndarray,strides=newstrides,shape=newshape,buffer=a,dtype=a.dtype)
  85     except TypeError:
  86         warnings.warn("Problem with ndarray creation forces copy.")
  87         a = a.copy()
  88         # Shape doesn't change but strides does
  89         newstrides = a.strides[:axis]+((length-overlap)*s,s) + a.strides[axis+1:]
  90         return N.ndarray.__new__(N.ndarray,strides=newstrides,shape=newshape,buffer=a,dtype=a.dtype)
  91         
  92 
  93 
  94 class TestSegment(NumpyTestCase):
  95     def test_simple(self):
  96         assert_equal(segment_axis(N.arange(6),length=3,overlap=0),
  97                          N.array([[0,1,2],[3,4,5]]))
  98 
  99         assert_equal(segment_axis(N.arange(7),length=3,overlap=1),
 100                          N.array([[0,1,2],[2,3,4],[4,5,6]]))
 101 
 102         assert_equal(segment_axis(N.arange(7),length=3,overlap=2),
 103                          N.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5],[4,5,6]]))
 104 
 105     def test_error_checking(self):
 106         self.assertRaises(ValueError,
 107                 lambda: segment_axis(N.arange(7),length=3,overlap=-1))
 108         self.assertRaises(ValueError,
 109                 lambda: segment_axis(N.arange(7),length=0,overlap=0))
 110         self.assertRaises(ValueError,
 111                 lambda: segment_axis(N.arange(7),length=3,overlap=3))
 112         self.assertRaises(ValueError,
 113                 lambda: segment_axis(N.arange(7),length=8,overlap=3))
 114 
 115     def test_ending(self):
 116         assert_equal(segment_axis(N.arange(6),length=3,overlap=1,end='cut'),
 117                          N.array([[0,1,2],[2,3,4]]))
 118         assert_equal(segment_axis(N.arange(6),length=3,overlap=1,end='wrap'),
 119                          N.array([[0,1,2],[2,3,4],[4,5,0]]))
 120         assert_equal(segment_axis(N.arange(6),length=3,overlap=1,end='pad',endvalue=-17),
 121                          N.array([[0,1,2],[2,3,4],[4,5,-17]]))
 122 
 123     def test_multidimensional(self):
 124         
 125         assert_equal(segment_axis(N.ones((2,3,4,5,6)),axis=3,length=3,overlap=1).shape,
 126                      (2,3,4,2,3,6))
 127 
 128         assert_equal(segment_axis(N.ones((2,5,4,3,6)).swapaxes(1,3),axis=3,length=3,overlap=1).shape,
 129                      (2,3,4,2,3,6))
 130 
 131         assert_equal(segment_axis(N.ones((2,3,4,5,6)),axis=2,length=3,overlap=1,end='cut').shape,
 132                      (2,3,1,3,5,6))
 133 
 134         assert_equal(segment_axis(N.ones((2,3,4,5,6)),axis=2,length=3,overlap=1,end='wrap').shape,
 135                      (2,3,2,3,5,6))
 136 
 137         assert_equal(segment_axis(N.ones((2,3,4,5,6)),axis=2,length=3,overlap=1,end='pad').shape,
 138                      (2,3,2,3,5,6))
 139 
 140 if __name__=='__main__':
 141     unittest.main()

New Attachment

File to upload
Rename to
Overwrite existing attachment of same name

Attached Files

To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily.