source: CIVL/examples/omp/m4ri/tests/test_multiplication.c

main
Last change on this file was ea777aa, checked in by Alex Wilton <awilton@…>, 3 years ago

Moved examples, include, build_default.properties, common.xml, and README out from dev.civl.com into the root of the repo.

git-svn-id: svn://vsl.cis.udel.edu/civl/trunk@5704 fb995dde-84ed-4084-dfe6-e5aef3e2452c

  • Property mode set to 100644
File size: 8.9 KB
Line 
1#include <m4ri/config.h>
2#include <stdlib.h>
3#include <m4ri/m4ri.h>
4
5/**
6 * Check that the results of all implemented multiplication algorithms
7 * match up.
8 *
9 * \param m Number of rows of A
10 * \param l Number of columns of A/number of rows of B
11 * \param n Number of columns of B
12 * \param k Parameter k of M4RM algorithm, may be 0 for automatic choice.
13 * \param cutoff Cut off parameter at which dimension to switch from
14 * Strassen to M4RM
15 */
16int mul_test_equality(rci_t m, rci_t l, rci_t n, int k, int cutoff) {
17 int ret = 0;
18 printf(" mul: m: %4d, l: %4d, n: %4d, k: %2d, cutoff: %4d", m, l, n, k, cutoff);
19
20 /* we create two random matrices */
21 mzd_t *A = mzd_init(m, l);
22 mzd_t *B = mzd_init(l, n);
23 mzd_randomize(A);
24 mzd_randomize(B);
25
26 /* C = A*B via Strassen */
27 mzd_t *C = mzd_mul(NULL, A, B, cutoff);
28
29 /* D = A*B via M4RM, temporary buffers are managed internally */
30 mzd_t *D = mzd_mul_m4rm( NULL, A, B, k);
31
32 if (mzd_equal(C, D) != TRUE) {
33 printf(" Strassen != M4RM");
34 ret -=1;
35 }
36
37 /* E = A*B via naive cubic multiplication */
38 mzd_t *E = mzd_mul_naive( NULL, A, B);
39
40 if (mzd_equal(D, E) != TRUE) {
41 printf(" M4RM != Naiv");
42 ret -= 1;
43 }
44
45 if (mzd_equal(C, E) != TRUE) {
46 printf(" Strassen != Naiv");
47 ret -= 1;
48 }
49
50#if __M4RI_HAVE_OPENMP
51 mzd_t *F = mzd_mul_mp(NULL, A, B, cutoff);
52 if (mzd_equal(C, F) != TRUE) {
53 printf(" MP != Naiv");
54 ret -= 1;
55 }
56 mzd_free(F);
57#endif
58
59 mzd_free(A);
60 mzd_free(B);
61 mzd_free(C);
62 mzd_free(D);
63 mzd_free(E);
64
65 if(ret==0) {
66 printf(" ... passed\n");
67 } else {
68 printf(" ... FAILED\n");
69 }
70
71 return ret;
72
73}
74
75/**
76 * Check that the results of all implemented squaring algorithms match
77 * up.
78 *
79 * \param m Number of rows and columns of A
80 * \param k Parameter k of M4RM algorithm, may be 0 for automatic choice.
81 * \param cutoff Cut off parameter at which dimension to switch from
82 * Strassen to M4RM
83 */
84int sqr_test_equality(rci_t m, int k, int cutoff) {
85 int ret = 0;
86 mzd_t *A, *C, *D, *E;
87
88 printf(" sqr: m: %4d, k: %2d, cutoff: %4d", m, k, cutoff);
89
90 /* we create one random matrix */
91 A = mzd_init(m, m);
92 mzd_randomize(A);
93
94 /* C = A*A via Strassen */
95 C = mzd_mul(NULL, A, A, cutoff);
96
97 /* D = A*A via M4RM, temporary buffers are managed internally */
98 D = mzd_mul_m4rm( NULL, A, A, k);
99
100 /* E = A*A via naive cubic multiplication */
101 E = mzd_mul_naive( NULL, A, A);
102
103 mzd_free(A);
104
105 if (mzd_equal(C, D) != TRUE) {
106 printf(" Strassen != M4RM");
107 ret -=1;
108 }
109
110 if (mzd_equal(D, E) != TRUE) {
111 printf(" M4RM != Naiv");
112 ret -= 1;
113 }
114
115 if (mzd_equal(C, E) != TRUE) {
116 printf(" Strassen != Naiv");
117 ret -= 1;
118 }
119
120 mzd_free(C);
121 mzd_free(D);
122 mzd_free(E);
123
124 if(ret==0) {
125 printf(" ... passed\n");
126 } else {
127 printf(" ... FAILED\n");
128 }
129
130 return ret;
131}
132
133int addmul_test_equality(rci_t m, rci_t l, rci_t n, int k, int cutoff) {
134 int ret = 0;
135 printf("addmul: m: %4d, l: %4d, n: %4d, k: %2d, cutoff: %4d", m, l, n, k, cutoff);
136
137 /* we create two random matrices */
138 mzd_t *A = mzd_init(m, l);
139 mzd_t *B = mzd_init(l, n);
140 mzd_t *C = mzd_init(m, n);
141 mzd_randomize(A);
142 mzd_randomize(B);
143 mzd_randomize(C);
144
145 /* D = C + A*B via M4RM, temporary buffers are managed internally */
146 mzd_t *D = mzd_copy(NULL, C);
147 D = mzd_addmul_m4rm(D, A, B, k);
148
149 /* E = C + A*B via naiv cubic multiplication */
150 mzd_t *E = mzd_mul_m4rm(NULL, A, B, k);
151 mzd_add(E, E, C);
152
153 if (mzd_equal(D, E) != TRUE) {
154 printf(" M4RM != add,mul");
155 ret -=1;
156 }
157
158 /* F = C + A*B via naiv cubic multiplication */
159 mzd_t *F = mzd_copy(NULL, C);
160 F = mzd_addmul(F, A, B, cutoff);
161
162 if (mzd_equal(E, F) != TRUE) {
163 printf(" add,mul = addmul");
164 ret -=1;
165 }
166 if (mzd_equal(F, D) != TRUE) {
167 printf(" M4RM != addmul");
168 ret -=1;
169 }
170
171#if __M4RI_HAVE_OPENMP
172 mzd_t *G = mzd_copy(NULL, C);
173 G = mzd_addmul_mp(G, A, B, cutoff);
174 if (mzd_equal(D, G) != TRUE) {
175 printf(" MP != Naiv");
176 ret -= 1;
177 }
178 mzd_free(G);
179#endif
180
181 if (ret==0)
182 printf(" ... passed\n");
183 else
184 printf(" ... FAILED\n");
185
186 mzd_free(A);
187 mzd_free(B);
188 mzd_free(C);
189 mzd_free(D);
190 mzd_free(E);
191 mzd_free(F);
192 return ret;
193}
194
195int addsqr_test_equality(rci_t m, int k, int cutoff) {
196 int ret = 0;
197 mzd_t *A, *C, *D, *E, *F;
198
199 printf("addsqr: m: %4d, k: %2d, cutoff: %4d", m, k, cutoff);
200
201 /* we create two random matrices */
202 A = mzd_init(m, m);
203 C = mzd_init(m, m);
204 mzd_randomize(A);
205 mzd_randomize(C);
206
207 /* D = C + A*B via M4RM, temporary buffers are managed internally */
208 D = mzd_copy(NULL, C);
209 D = mzd_addmul_m4rm(D, A, A, k);
210
211 /* E = C + A*B via naive cubic multiplication */
212 E = mzd_mul_m4rm(NULL, A, A, k);
213 mzd_add(E, E, C);
214
215 /* F = C + A*B via naive cubic multiplication */
216 F = mzd_copy(NULL, C);
217 F = mzd_addmul(F, A, A, cutoff);
218
219 mzd_free(A);
220 mzd_free(C);
221
222 if (mzd_equal(D, E) != TRUE) {
223 printf(" M4RM != add,mul");
224 ret -=1;
225 }
226 if (mzd_equal(E, F) != TRUE) {
227 printf(" add,mul = addmul");
228 ret -=1;
229 }
230 if (mzd_equal(F, D) != TRUE) {
231 printf(" M4RM != addmul");
232 ret -=1;
233 }
234
235 if (ret==0)
236 printf(" ... passed\n");
237 else
238 printf(" ... FAILED\n");
239
240
241 mzd_free(D);
242 mzd_free(E);
243 mzd_free(F);
244 return ret;
245}
246
247int main() {
248 int status = 0;
249
250 srandom(17);
251
252 status += mul_test_equality( 1, 1, 1, 0, 1024);
253 status += mul_test_equality( 1, 128, 128, 0, 0);
254 status += mul_test_equality( 3, 131, 257, 0, 0);
255 status += mul_test_equality( 64, 64, 64, 0, 64);
256 status += mul_test_equality( 128, 128, 128, 0, 64);
257 status += mul_test_equality( 21, 171, 31, 0, 63);
258 status += mul_test_equality( 21, 171, 31, 0, 131);
259 status += mul_test_equality( 193, 65, 65, 8, 64);
260 status += mul_test_equality(1025, 1025, 1025, 3, 256);
261 status += mul_test_equality(2048, 2048, 4096, 0, 1024);
262 status += mul_test_equality(4096, 3528, 4096, 0, 1024);
263 status += mul_test_equality(1024, 1025, 1, 0, 1024);
264 status += mul_test_equality(1000, 1000, 1000, 0, 256);
265 status += mul_test_equality(1000, 10, 20, 0, 64);
266 status += mul_test_equality(1710, 1290, 1000, 0, 256);
267 status += mul_test_equality(1290, 1710, 200, 0, 64);
268 status += mul_test_equality(1290, 1710, 2000, 0, 256);
269 status += mul_test_equality(1290, 1290, 2000, 0, 64);
270 status += mul_test_equality(1000, 210, 200, 0, 64);
271
272 status += addmul_test_equality( 1, 128, 128, 0, 0);
273 status += addmul_test_equality( 3, 131, 257, 0, 0);
274 status += addmul_test_equality( 64, 64, 64, 0, 64);
275 status += addmul_test_equality( 128, 128, 128, 0, 64);
276 status += addmul_test_equality( 21, 171, 31, 0, 63);
277 status += addmul_test_equality( 21, 171, 31, 0, 131);
278 status += addmul_test_equality( 193, 65, 65, 8, 64);
279 status += addmul_test_equality(1025, 1025, 1025, 3, 256);
280 status += addmul_test_equality(4096, 4096, 4096, 0, 2048);
281 status += addmul_test_equality(1000, 1000, 1000, 0, 256);
282 status += addmul_test_equality(1000, 10, 20, 0, 64);
283 status += addmul_test_equality(1710, 1290, 1000, 0, 256);
284 status += addmul_test_equality(1290, 1710, 200, 0, 64);
285 status += addmul_test_equality(1290, 1710, 2000, 0, 256);
286 status += addmul_test_equality(1290, 1290, 2000, 0, 64);
287 status += addmul_test_equality(1000, 210, 200, 0, 64);
288
289 status += sqr_test_equality( 1, 0, 1024);
290 status += sqr_test_equality( 128, 0, 0);
291 status += sqr_test_equality( 131, 0, 0);
292 status += sqr_test_equality( 64, 0, 64);
293 status += sqr_test_equality( 128, 0, 64);
294 status += sqr_test_equality( 171, 0, 63);
295 status += sqr_test_equality( 171, 0, 131);
296 status += sqr_test_equality( 193, 8, 64);
297 status += sqr_test_equality(1025, 3, 256);
298 status += sqr_test_equality(2048, 0, 1024);
299 status += sqr_test_equality(3528, 0, 1024);
300 status += sqr_test_equality(1000, 0, 256);
301 status += sqr_test_equality(1000, 0, 64);
302 status += sqr_test_equality(1710, 0, 256);
303 status += sqr_test_equality(1290, 0, 64);
304 status += sqr_test_equality(2000, 0, 256);
305 status += sqr_test_equality(2000, 0, 64);
306 status += sqr_test_equality( 210, 0, 64);
307
308 status += addsqr_test_equality( 1, 0, 0);
309 status += addsqr_test_equality( 131, 0, 0);
310 status += addsqr_test_equality( 64, 0, 64);
311 status += addsqr_test_equality( 128, 0, 64);
312 status += addsqr_test_equality( 171, 0, 63);
313 status += addsqr_test_equality( 171, 0, 131);
314 status += addsqr_test_equality( 193, 8, 64);
315 status += addsqr_test_equality(1025, 3, 256);
316 status += addsqr_test_equality(4096, 0, 2048);
317 status += addsqr_test_equality(1000, 0, 256);
318 status += addsqr_test_equality(1000, 0, 64);
319 status += addsqr_test_equality(1710, 0, 256);
320 status += addsqr_test_equality(1290, 0, 64);
321 status += addsqr_test_equality(2000, 0, 256);
322 status += addsqr_test_equality(2000, 0, 64);
323 status += addsqr_test_equality( 210, 0, 64);
324
325 if (status == 0) {
326 printf("All tests passed.\n");
327 return 0;
328 } else {
329 return -1;
330 }
331}
Note: See TracBrowser for help on using the repository browser.