Attachment 'ndarray.h'
Download
Toggle line numbers
1 //==========================================================
2 // ndarray.h: C++ interface for easy access to numpy arrays
3 //
4 // J. De Ridder
5 //==========================================================
6
7
8 #ifndef NDARRAY_H
9 #define NDARRAY_H
10
11
12
// Plain C-layout mirror of the ctypes structure that the Python side fills in.
// Only the ORDER of the members (i.e. the memory layout) has to agree with the
// Python ctypes.Structure definition; member names are not matched at all.
// NOTE(review): `long` is 32 bits on LLP64 platforms (64-bit Windows), while
// numpy's shape/stride entries are pointer-sized there; confirm the Python side
// declares matching ctypes field types.

template <typename T>
struct numpyArray
{
    T*    data;     // element at index (0, 0, ...)
    long* shape;    // length of each axis
    long* strides;  // step along each axis (used by Ndarray as element counts)
};
25
26
27
28
// Element-access traits.
// Indexing an N-dimensional array with [] yields either an (N-1)-dimensional
// sub-array or, once the last axis is reached, a reference to a single element.
// The result type therefore depends on N, and these traits select it.

template <class T, int N> class Ndarray;   // defined further below

// General case: peeling off one axis leaves an (N-1)-dimensional view.
template <class T, int N>
struct getItemTraits
{
    typedef Ndarray<T, N - 1> returnType;
};

// Last axis: indexing a 1-D array gives direct (writable) access to an element.
template <class T>
struct getItemTraits<T, 1>
{
    typedef T& returnType;
};
47
48
49
50 // Ndarray definition
51
52 template<typename datatype, int ndim>
53 class Ndarray
54 {
55 private:
56 datatype *data;
57 long *shape;
58 long *strides;
59
60 public:
61 Ndarray(datatype *data, long *shape, long *strides);
62 Ndarray(const Ndarray<datatype, ndim>& array);
63 Ndarray(const numpyArray<datatype>& array);
64 long getShape(const int axis);
65 typename getItemTraits<datatype, ndim>::returnType operator[](unsigned long i);
66 };
67
68
69 // Ndarray constructor
70
71 template<typename datatype, int ndim>
72 Ndarray<datatype, ndim>::Ndarray(datatype *data, long *shape, long *strides)
73 {
74 this->data = data;
75 this->shape = shape;
76 this->strides = strides;
77 }
78
79
80 // Ndarray copy constructor
81
82 template<typename datatype, int ndim>
83 Ndarray<datatype, ndim>::Ndarray(const Ndarray<datatype, ndim>& array)
84 {
85 this->data = array.data;
86 this->shape = array.shape;
87 this->strides = array.strides;
88 }
89
90
91 // Ndarray constructor from ctypes structure
92
93 template<typename datatype, int ndim>
94 Ndarray<datatype, ndim>::Ndarray(const numpyArray<datatype>& array)
95 {
96 this->data = array.data;
97 this->shape = array.shape;
98 this->strides = array.strides;
99 }
100
101
102 // Ndarray method to get length of given axis
103
104 template<typename datatype, int ndim>
105 long Ndarray<datatype, ndim>::getShape(const int axis)
106 {
107 return this->shape[axis];
108 }
109
110
111
112 // Ndarray overloaded []-operator.
113 // The [i][j][k] selection is recursively replaced by i*strides[0]+j*strides[1]+k*strides[2]
114 // at compile time, using template meta-programming. If the axes are not exhausted, return
115 // a subarray, else return an element.
116
117 template<typename datatype, int ndim>
118 typename getItemTraits<datatype, ndim>::returnType
119 Ndarray<datatype, ndim>::operator[](unsigned long i)
120 {
121 return Ndarray<datatype, ndim-1>(&this->data[i*this->strides[0]], &this->shape[1], &this->strides[1]);
122 }
123
124
125
126 // Template partial specialisation of Ndarray.
127 // For 1D Ndarrays, the [] operator should return an element, not a subarray, so it needs
128 // to be special-cased. In principle only the operator[] method should be specialised, but
129 // for some reason my gcc version seems to require that then the entire class with all its
130 // methods are specialised.
131
132 template<typename datatype>
133 class Ndarray<datatype, 1>
134 {
135 private:
136 datatype *data;
137 long *shape;
138 long *strides;
139
140 public:
141 Ndarray(datatype *data, long *shape, long *strides);
142 Ndarray(const Ndarray<datatype, 1>& array);
143 Ndarray(const numpyArray<datatype>& array);
144 long getShape(const int axis);
145 typename getItemTraits<datatype, 1>::returnType operator[](unsigned long i);
146 };
147
148
149 // Ndarray partial specialised constructor
150
151 template<typename datatype>
152 Ndarray<datatype, 1>::Ndarray(datatype *data, long *shape, long *strides)
153 {
154 this->data = data;
155 this->shape = shape;
156 this->strides = strides;
157 }
158
159
160
161 // Ndarray partially specialised copy constructor
162
163 template<typename datatype>
164 Ndarray<datatype, 1>::Ndarray(const Ndarray<datatype, 1>& array)
165 {
166 this->data = array.data;
167 this->shape = array.shape;
168 this->strides = array.strides;
169 }
170
171
172
173 // Ndarray partially specialised constructor from ctypes structure
174
175 template<typename datatype>
176 Ndarray<datatype, 1>::Ndarray(const numpyArray<datatype>& array)
177 {
178 this->data = array.data;
179 this->shape = array.shape;
180 this->strides = array.strides;
181 }
182
183
184
185 // Ndarray method to get length of given axis
186
187 template<typename datatype>
188 long Ndarray<datatype, 1>::getShape(const int axis)
189 {
190 return this->shape[axis];
191 }
192
193
194
195 // Partial specialised [] operator: for 1D arrays, return an element rather than a subarray
196
197 template<typename datatype>
198 typename getItemTraits<datatype, 1>::returnType
199 Ndarray<datatype, 1>::operator[](unsigned long i)
200 {
201 return this->data[i*this->strides[0]];
202 }
203
204
205
206 #endif
207