1 import numpy as N
2 import unittest
3 from numpy.testing import NumpyTestCase, assert_array_almost_equal, assert_almost_equal, assert_equal
4 import warnings
5
def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
    """Chop an array along one axis into (possibly overlapping) frames.

    Example::

        >>> segment_axis(N.arange(10), 4, 2)
        array([[0, 1, 2, 3],
               [2, 3, 4, 5],
               [4, 5, 6, 7],
               [6, 7, 8, 9]])

    Arguments:
        a         The array (or array-like) to segment.
        length    The length of each frame; must be positive.
        overlap   The number of elements by which consecutive frames
                  overlap; must satisfy 0 <= overlap < length.
        axis      The axis to operate on; if None, act on the flattened
                  array.  Negative values count from the last axis.
        end       What to do with the last frame if the array does not
                  divide evenly into frames:

                    'cut'   Discard the extra values.
                    'wrap'  Fill the last frame with values taken
                            cyclically from the beginning of the array.
                    'pad'   Fill the last frame with ``endvalue``.

        endvalue  The fill value used when end='pad'.

    Returns:
        An array shaped like ``a`` but with the ``axis`` dimension replaced
        by ``(n_frames, length)``.  The result is a strided view of the
        (possibly trimmed or extended) data, so overlapping frames share
        memory: writing into one frame also changes its neighbours.  Data
        is copied only when ``a`` must be flattened and is not contiguous,
        or when end is 'pad'/'wrap' and the array does not divide evenly.

    Raises:
        ValueError  If length/overlap are out of range, if ``end`` is not
                    a supported mode, if 'cut' leaves no complete frame,
                    or if 'wrap' is asked to extend an empty array.
        IndexError  If ``axis`` is out of range for ``a``.
    """
    # Local import keeps the module's top-level imports unchanged.
    from numpy.lib.stride_tricks import as_strided

    if axis is None:
        a = N.ravel(a)
        axis = 0
    else:
        a = N.asarray(a)
    if axis < 0:
        # Normalise: the shape/stride splicing below (a.shape[axis + 1:])
        # is only correct for a non-negative axis index.
        axis += a.ndim
    if not 0 <= axis < a.ndim:
        raise IndexError("axis out of range for array of dimension %d" % a.ndim)

    if overlap < 0 or length <= 0:
        raise ValueError("overlap must be nonnegative and length must be positive")
    if overlap >= length:
        raise ValueError("frames cannot overlap by more than 100%")
    if end not in ('cut', 'pad', 'wrap'):
        raise ValueError("end must be 'cut', 'pad' or 'wrap', not %r" % (end,))

    step = length - overlap
    l = a.shape[axis]

    if l < length or (l - length) % step:
        # The data does not end on a frame boundary: find the nearest
        # frame-aligned lengths below (rounddown) and above (roundup).
        if l > length:
            rounddown = length + ((l - length) // step) * step
            roundup = rounddown + step
        else:
            rounddown = 0
            roundup = length

        # Move the segmented axis last so ellipsis slicing can be used.
        a = a.swapaxes(-1, axis)
        if end == 'cut':
            a = a[..., :rounddown]
        else:
            if end == 'wrap' and l == 0:
                raise ValueError("cannot wrap an empty array")
            shape = list(a.shape)
            shape[-1] = roundup
            b = N.empty(shape, dtype=a.dtype)
            b[..., :l] = a
            if end == 'pad':
                b[..., l:] = endvalue
            else:
                # mode='wrap' cycles through the data, so this also works
                # when the gap to fill is longer than the array itself.
                b[..., l:] = a.take(N.arange(l, roundup), axis=-1, mode='wrap')
            a = b
        a = a.swapaxes(-1, axis)

    l = a.shape[axis]
    if l == 0:
        raise ValueError("Not enough data points to segment array in 'cut' mode; try 'pad' or 'wrap'")
    n = 1 + (l - length) // step
    s = a.strides[axis]
    newshape = a.shape[:axis] + (n, length) + a.shape[axis + 1:]
    newstrides = a.strides[:axis] + (step * s, s) + a.strides[axis + 1:]
    # as_strided builds the overlapping view from a's own strides, so it
    # works for non-contiguous inputs without the buffer protocol or a copy.
    return as_strided(a, shape=newshape, strides=newstrides)
91
92
93
class TestSegment(unittest.TestCase):
    """Regression tests for segment_axis.

    Subclasses the standard unittest.TestCase, which is what the module's
    unittest.main() entry point drives; numpy.testing.NumpyTestCase is a
    legacy class that current numpy releases no longer provide.
    """

    def test_simple(self):
        # Evenly divisible inputs with no, partial and maximal overlap.
        assert_equal(segment_axis(N.arange(6), length=3, overlap=0),
                     N.array([[0, 1, 2], [3, 4, 5]]))

        assert_equal(segment_axis(N.arange(7), length=3, overlap=1),
                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]]))

        assert_equal(segment_axis(N.arange(7), length=3, overlap=2),
                     N.array([[0, 1, 2], [1, 2, 3], [2, 3, 4],
                              [3, 4, 5], [4, 5, 6]]))

    def test_error_checking(self):
        # Negative overlap, zero length, 100% overlap, and a 'cut' that
        # leaves no complete frame must all raise ValueError.
        self.assertRaises(ValueError, segment_axis,
                          N.arange(7), length=3, overlap=-1)
        self.assertRaises(ValueError, segment_axis,
                          N.arange(7), length=0, overlap=0)
        self.assertRaises(ValueError, segment_axis,
                          N.arange(7), length=3, overlap=3)
        self.assertRaises(ValueError, segment_axis,
                          N.arange(7), length=8, overlap=3)

    def test_ending(self):
        # arange(6) with length 3 / overlap 1 does not divide evenly, so
        # each end mode treats the final partial frame differently.
        assert_equal(segment_axis(N.arange(6), length=3, overlap=1, end='cut'),
                     N.array([[0, 1, 2], [2, 3, 4]]))
        assert_equal(segment_axis(N.arange(6), length=3, overlap=1, end='wrap'),
                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, 0]]))
        assert_equal(segment_axis(N.arange(6), length=3, overlap=1, end='pad',
                                  endvalue=-17),
                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, -17]]))

    def test_multidimensional(self):
        # Segmenting one axis of an N-d array replaces that dimension with
        # (n_frames, length); the second case segments a swapaxes view.
        assert_equal(segment_axis(N.ones((2, 3, 4, 5, 6)), axis=3,
                                  length=3, overlap=1).shape,
                     (2, 3, 4, 2, 3, 6))

        assert_equal(segment_axis(N.ones((2, 5, 4, 3, 6)).swapaxes(1, 3),
                                  axis=3, length=3, overlap=1).shape,
                     (2, 3, 4, 2, 3, 6))

        assert_equal(segment_axis(N.ones((2, 3, 4, 5, 6)), axis=2,
                                  length=3, overlap=1, end='cut').shape,
                     (2, 3, 1, 3, 5, 6))

        assert_equal(segment_axis(N.ones((2, 3, 4, 5, 6)), axis=2,
                                  length=3, overlap=1, end='wrap').shape,
                     (2, 3, 2, 3, 5, 6))

        assert_equal(segment_axis(N.ones((2, 3, 4, 5, 6)), axis=2,
                                  length=3, overlap=1, end='pad').shape,
                     (2, 3, 2, 3, 5, 6))
139
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs this module's TestCase classes.
    unittest.main()