source: CIVL/examples/verifyThis/matrixMult.cvl

main
Last change on this file was ea777aa, checked in by Alex Wilton <awilton@…>, 3 years ago

Moved examples, include, build_default.properties, common.xml, and README out from dev.civl.com into the root of the repo.

git-svn-id: svn://vsl.cis.udel.edu/civl/trunk@5704 fb995dde-84ed-4084-dfe6-e5aef3e2452c

  • Property mode set to 100644
File size: 8.4 KB
Line 
/* VerifyThis 2016 - Challenge 1: Matrix Multiplication
 * Consider the following pseudocode algorithm, which is a naive
 * implementation of matrix multiplication. For simplicity we assume that
 * the matrices are square.
5
6 int[][] matrixMultiply(int[][] A, int[][] B) {
7 int n = A.length;
8
9 // initialise C
10 int[][] C = new int[n][n];
11
12 for (int i = 0; i < n; i++) {
13 for (int k = 0; k < n; k++) {
14 for (int j = 0; j < n; j++) {
15 C[i][j] += A[i][k] * B[k][j];
16 }
17 }
18 }
19 return C;
20 }
21
22 * Tasks.
23 * (1) Provide a specification to describe the behaviour of this algorithm,
24 * and prove that it correctly implements its specification.
25 * (2) Show that matrix multiplication is associative, i.e., the order in
26 * which matrices are multiplied can be disregarded: A(BC) = (AB)C. To show
27 * this, you should write a program that performs the two different
28 * computations, and then prove that the result of the two computations is
29 * always the same.
 * (3) In the literature, there exist many proposals for more efficient
 * matrix multiplication algorithms. Strassen’s algorithm was one of the
 * first. The key idea of the algorithm is to use a recursive algorithm
 * that reduces the number of multiplications on submatrices (from 8 to 7),
 * see https://en.wikipedia.org/wiki/Strassen_algorithm for an explanation.
 * A relatively clean Java implementation (and Python and C++) can be found
 * here: https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/.
 * Prove that the naive algorithm above has the same behaviour as Strassen’s
 * algorithm. Proving it for a restricted case, like a 2x2 matrix, should be
 * straightforward; the challenge is to prove it for arbitrary matrices with
 * size 2^n.
41 *
42 * Author: Stephen Siegel
43 */
44
45#include <civlc.cvh>
46#include <stdio.h>
47#include <pointer.cvh>
48
49// upper bound on N, the size of the matrices
50$input int BOUND = 4; // can go up to 16 if you have a few minutes
51$assume(BOUND >= 1);
52$input int N=4; // the size of the matrices
53$assume(1<=N && N<=BOUND);
54// some arbitrary input matrices...
55$input float A0[N][N];
56$input float B0[N][N];
57$input float C0[N][N];
58
59// the "leaf size" for Strassen...
60$input int LEAF_SIZE;
61$assume (2 <= LEAF_SIZE && LEAF_SIZE <= N);
62
63/* Part 1 */
64
65// impl: C is "out" variable
66void matrixMultiply(int n, float C[][], float A[][], float B[][]) {
67 for (int i=0; i<n; i++)
68 for (int j=0; j<n; j++)
69 C[i][j] = 0.0;
70
71 for (int i = 0; i < n; i++) {
72 for (int k = 0; k < n; k++) {
73 for (int j = 0; j < n; j++) {
74 C[i][j] += A[i][k] * B[k][j];
75 }
76 }
77 }
78}
79
80// Unfortunately no easy way to specify the sum...
81
82// Some "helpers" for verification...
83
// Returns the dot product u[0]*v[0] + ... + u[n-1]*v[n-1] of two vectors of
// length n, accumulated left to right from 0 (so n <= 0 yields 0).
float dot(int n, float u[], float v[]) {
  float acc = 0;
  int idx = 0;

  while (idx < n) {
    acc += u[idx] * v[idx];
    idx++;
  }
  return acc;
}
92
93// gets the index-th column of matrix mat, putting it in
94// contiguous memory starting from result. Returns
95// pointer to element 0 of result.
96float * column(int n, float result[], float mat[][], int index) {
97 for (int i=0; i<n; i++)
98 result[i] = mat[i][index];
99 return &result[0];
100}
101
102
103// checks multiplication is correct: entry i,j should
104// be the dot product of i-th row of A and j-th column of B...
105void testMult(int n) {
106 float actual[n][n];
107 float tmp[n];
108
109 matrixMultiply(n, actual, A0, B0);
110 for (int i=0; i<n; i++)
111 for (int j=0; j<n; j++)
112 $assert(dot(n, A0[i], column(n, tmp, B0, j)) == actual[i][j]);
113}
114
115
116
117/* Part 2 */
118
119// tests associativity: (A0*B0)*C0 = A0*(B0*C0)
120void assoc(int n) {
121 float T1[n][n], T2[n][n], R1[n][n], R2[n][n];
122
123 matrixMultiply(n, T1, A0, B0);
124 matrixMultiply(n, R1, T1, C0);
125 matrixMultiply(n, T2, B0, C0);
126 matrixMultiply(n, R2, A0, T2);
127 $assert($equals(&R1, &R2));
128}
129
130
131/* Part 3 : Strassen */
132
133// adds two nxn matrices. C is "out" variable.
134void add(int n, float C[][], float A[][], float B[][]) {
135 for (int i = 0; i < n; i++)
136 for (int j = 0; j < n; j++)
137 C[i][j] = A[i][j] + B[i][j];
138}
139
140// subtracts two nxn matrices. C is "out" variable.
141void subtract(int n, float C[][], float A[][], float B[][]) {
142 for (int i = 0; i < n; i++)
143 for (int j = 0; j < n; j++)
144 C[i][j] = A[i][j] - B[i][j];
145}
146
147
148// Strassen algorithm from
149// https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/
150// I'm just going to assume n is a power of 2.
151// There is no problem dealing with the general case but need more
152// time!
153
154// multiplies two nxn matrices, storing result in C
155void strassenR(int n, float C[][], float A[][], float B[][]) {
156 if (n <= LEAF_SIZE) {
157 matrixMultiply(n, C, A, B);
158 } else {
159 // initializing the new sub-matrices
160 int newSize = n / 2;
161 float a11[newSize][newSize];
162 float a12[newSize][newSize];
163 float a21[newSize][newSize];
164 float a22[newSize][newSize];
165
166 float b11[newSize][newSize];
167 float b12[newSize][newSize];
168 float b21[newSize][newSize];
169 float b22[newSize][newSize];
170
171 float aResult[newSize][newSize];
172 float bResult[newSize][newSize];
173
174 // dividing the matrices in 4 sub-matrices:
175 for (int i = 0; i < newSize; i++) {
176 for (int j = 0; j < newSize; j++) {
177 a11[i][j] = A[i][j]; // top left
178 a12[i][j] = A[i][j + newSize]; // top right
179 a21[i][j] = A[i + newSize][j]; // bottom left
180 a22[i][j] = A[i + newSize][j + newSize]; // bottom right
181
182 b11[i][j] = B[i][j]; // top left
183 b12[i][j] = B[i][j + newSize]; // top right
184 b21[i][j] = B[i + newSize][j]; // bottom left
185 b22[i][j] = B[i + newSize][j + newSize]; // bottom right
186 }
187 }
188 // Calculating p1 to p7:
189 add(newSize, aResult, a11, a22);
190 add(newSize, bResult, b11, b22);
191 float p1[newSize][newSize];
192 strassenR(newSize, p1, aResult, bResult);
193 // p1 = (a11+a22) * (b11+b22)
194
195 add(newSize, aResult, a21, a22); // a21 + a22
196 float p2[newSize][newSize];
197 strassenR(newSize, p2, aResult, b11); // p2 = (a21+a22) * (b11)
198
199 subtract(newSize, bResult, b12, b22); // b12 - b22
200 float p3[newSize][newSize];
201 strassenR(newSize, p3, a11, bResult);
202 // p3 = (a11) * (b12 - b22)
203
204 subtract(newSize, bResult, b21, b11); // b21 - b11
205 float p4[newSize][newSize];
206 strassenR(newSize, p4, a22, bResult);
207 // p4 = (a22) * (b21 - b11)
208
209 add(newSize, aResult, a11, a12); // a11 + a12
210 float p5[newSize][newSize];
211 strassenR(newSize, p5, aResult, b22);
212 // p5 = (a11+a12) * (b22)
213
214 subtract(newSize, aResult, a21, a11); // a21 - a11
215 add(newSize, bResult, b11, b12); // b11 + b12
216 float p6[newSize][newSize];
217 strassenR(newSize, p6, aResult, bResult);
218 // p6 = (a21-a11) * (b11+b12)
219
220 subtract(newSize, aResult, a12, a22); // a12 - a22
221 add(newSize, bResult, b21, b22); // b21 + b22
222 float p7[newSize][newSize];
223 strassenR(newSize, p7, aResult, bResult);
224 // p7 = (a12-a22) * (b21+b22)
225
226 // calculating c21, c21, c11 e c22:
227 float c12[newSize][newSize];
228 add(newSize, c12, p3, p5); // c12 = p3 + p5
229 float c21[newSize][newSize];
230 add(newSize, c21, p2, p4); // c21 = p2 + p4
231
232 add(newSize, aResult, p1, p4); // p1 + p4
233 add(newSize, bResult, aResult, p7); // p1 + p4 + p7
234 float c11[newSize][newSize];
235 subtract(newSize, c11, bResult, p5);
236 // c11 = p1 + p4 - p5 + p7
237
238 add(newSize, aResult, p1, p3); // p1 + p3
239 add(newSize, bResult, aResult, p6); // p1 + p3 + p6
240 float c22[newSize][newSize];
241 subtract(newSize, c22, bResult, p2);
242 // c22 = p1 + p3 - p2 + p6
243
244 // Grouping the results obtained in a single matrix:
245 for (int i = 0; i < newSize; i++) {
246 for (int j = 0; j < newSize; j++) {
247 C[i][j] = c11[i][j];
248 C[i][j + newSize] = c12[i][j];
249 C[i + newSize][j] = c21[i][j];
250 C[i + newSize][j + newSize] = c22[i][j];
251 }
252 }
253 }
254}
255
256// test Strassen multiplication agrees with the regular one...
257void testStrassen(int n) {
258 float R1[n][n], R2[n][n];
259
260 matrixMultiply(n, R1, A0, B0);
261 strassenR(n, R2, A0, B0);
262 $assert($equals(&R1, &R2));
263}
264
265// determines whether n is a power of 2
266_Bool isPowerOf2(int n) {
267 while (n>1) {
268 if (n%2 != 0)
269 return $false;
270 n = n/2;
271 }
272 return $true;
273}
274
275/* main: runs the three tests */
276int main() {
277 //$elaborate(N); // hint to verifier
278 printf("N=%d\n", N);
279 //testMult(N);
280 //assoc(N);
281 //$assume(isPowerOf2(N));
282 testStrassen(N);
283}
Note: See TracBrowser for help on using the repository browser.