1 import numpy as npy
2 from scipy.weave import inline
3 from numpy.testing import assert_array_almost_equal, assert_almost_equal
4
def prodsum(a, b, axis=None):
    """Return the sum of the elementwise product of ``a`` and ``b``.

    Numerically this is ``(a * b).sum(axis=axis)``. The multiply-accumulate
    loop is C code compiled through ``scipy.weave.inline``, so no temporary
    ``a * b`` array is built.

    Parameters
    ----------
    a, b : array_like
        Inputs of identical shape. Both are converted with
        ``numpy.asarray(..., dtype=float64)`` because the C kernel reads
        elements as ``double``. Each operand is stepped with its own strides,
        so ``a`` and ``b`` may have different memory layouts, including
        non-contiguous and reversed views.
    axis : int, optional
        Axis to reduce over; negative values count from the end. If None
        (the default), all elements are reduced to a single scalar.

    Returns
    -------
    numpy.float64 or ndarray
        The scalar total when ``axis`` is None; otherwise a float64 array
        with the shape of ``a`` with axis ``axis`` removed.

    Raises
    ------
    AssertionError
        If the shapes differ or ``axis`` is outside ``[-ndim, ndim)``.
    """
    a = npy.asarray(a, dtype=npy.float64)
    b = npy.asarray(b, dtype=npy.float64)
    assert a.shape == b.shape, "cannot take prodsum of different size arrays"
    nd = a.ndim

    if nd == 0:
        # A 0-d array has no axis for PyArray_IterAllButAxis to leave out,
        # and the kernel would read Sa[-1]. A scalar product needs no loop.
        assert axis is None, "cannot perform operation in this axis: %d" % axis
        return npy.float64(a * b)

    if axis is not None:
        ax = int(axis)  # weave converts plain Python ints, not numpy ints
        caxis = ax if ax >= 0 else nd + ax
        # Check the lower bound too. A negative caxis would otherwise reach
        # the kernel and be treated as "sum everything".
        assert 0 <= caxis < nd, "cannot perform operation in this axis: %d" % axis
        dims = list(a.shape)
        dims.pop(caxis)
        c = npy.zeros(tuple(dims), npy.float64)
    else:
        # caxis = -1 tells PyArray_IterAllButAxis to pick the axis itself.
        # Every partial sum then accumulates into the single cell c[0].
        caxis = -1
        c = npy.array([0.0])

    xtra = \
"""
/* Dot product of n doubles. Each operand is advanced by its own stride,
   given in BYTES, so mismatched, negative, or odd strides are handled. */
static double prodsum_strided(const char *p1, npy_intp s1,
                              const char *p2, npy_intp s2, npy_intp n)
{
    double s = 0.0;
    while (n-- > 0) {
        s += (*(const double *) p1) * (*(const double *) p2);
        p1 += s1;
        p2 += s2;
    }
    return s;
}
"""

    code = \
"""
int sumall = caxis < 0 ? 1 : 0;
PyArrayIterObject *itr1, *itr2, *itr3 = NULL;
// If caxis is negative, the first call chooses the axis and writes it back
// into caxis, so the second call iterates b over the same axis.
itr1 = (PyArrayIterObject *) PyArray_IterAllButAxis(py_a, &caxis);
itr2 = (PyArrayIterObject *) PyArray_IterAllButAxis(py_b, &caxis);
if (!sumall) itr3 = (PyArrayIterObject *) PyArray_IterNew(py_c);
// weave defines Sx (byte strides) and Nx (dims) for each array argument.
npy_intp stride1 = Sa[caxis];
npy_intp stride2 = Sb[caxis];
npy_intp size = Na[caxis];
double *d3;
while (PyArray_ITER_NOTDONE(itr1)) {
    if (sumall) {
        d3 = c;
    } else {
        d3 = (double *) itr3->dataptr;
        PyArray_ITER_NEXT(itr3);
    }
    *d3 += prodsum_strided(itr1->dataptr, stride1,
                           itr2->dataptr, stride2, size);
    PyArray_ITER_NEXT(itr1);
    PyArray_ITER_NEXT(itr2);
}
// The iterators are new references; release them so they do not leak.
Py_DECREF((PyObject *) itr1);
Py_DECREF((PyObject *) itr2);
Py_XDECREF((PyObject *) itr3);
"""
    inline(code, ['a', 'b', 'c', 'caxis'], compiler='gcc',
           support_code=xtra)
    return c[0] if axis is None else c
60
61
def tests():
    """Self-test: compare prodsum against numpy sums on a random array.

    The second operand is all ones, so prodsum(x, ones) must match x.sum().
    Cases: the full reduction, the last axis, axis 0 of a strided slice,
    and the last axis of a reversed view. Prints "all passed" at the end.
    """
    x = npy.random.rand(4, 2, 9)
    ones = npy.ones_like(x)

    # Full reduction to a scalar.
    assert_almost_equal(prodsum(x, ones), x.sum())

    # (index expression, axis); the same view is taken of both operands.
    cases = [
        (npy.s_[...], -1),
        (npy.s_[:2, :, 1::2], 0),
        (npy.s_[:, :, ::-1], -1),
    ]
    for idx, ax in cases:
        view = x[idx]
        assert_array_almost_equal(prodsum(view, ones[idx], axis=ax),
                                  view.sum(axis=ax))
    print("all passed")