source: CIVL/examples/verifyThis/mm4.cvl

main
Last change on this file was ea777aa, checked in by Alex Wilton <awilton@…>, 3 years ago

Moved examples, include, build_default.properties, common.xml, and README out from dev.civl.com into the root of the repo.

git-svn-id: svn://vsl.cis.udel.edu/civl/trunk@5704 fb995dde-84ed-4084-dfe6-e5aef3e2452c

  • Property mode set to 100644
File size: 6.5 KB
Line 
1/* Post-commit solution to matrixMultiplication, using CIVL.
2 * Stephen Siegel
3 */
4
5#include <civlc.cvh>
6#include <stdio.h>
7#include <pointer.cvh>
8
9// upper bound on N, the size of the matrices
10$input int BOUND = 8; // can go up to 16 if you have a few minutes
11$assume(BOUND >= 1);
12$input int N; // the size of the matrices
13$assume(1<=N && N<=BOUND);
14// some arbitrary input matrices...
15$input float A0[N][N];
16$input float B0[N][N];
17$input float C0[N][N];
18
19// the "leaf size" for Strassen...
20$input int LEAF_SIZE;
21$assume (2 <= LEAF_SIZE && LEAF_SIZE <= N);
22
23/* Part 1 */
24
25// impl: C is "out" variable
26void matrixMultiply(int n, float C[][], float A[][], float B[][]) {
27 for (int i=0; i<n; i++)
28 for (int j=0; j<n; j++)
29 C[i][j] = 0.0;
30
31 for (int i = 0; i < n; i++) {
32 for (int k = 0; k < n; k++) {
33 for (int j = 0; j < n; j++) {
34 C[i][j] += A[i][k] * B[k][j];
35 }
36 }
37 }
38}
39
40// Unfortunately no easy way to specify the sum...
41
42// Some "helpers" for verification...
43
// Returns the dot product u[0]*v[0] + ... + u[n-1]*v[n-1] of two
// vectors of length n.  The result is 0 when n <= 0.
float dot(int n, float u[], float v[]) {
  float total = 0;
  int k = 0;

  while (k < n) {
    total = total + u[k] * v[k];
    k++;
  }
  return total;
}
52
53// gets the index-th column of matrix mat, putting it in
54// contiguous memory starting from result. Returns
55// pointer to element 0 of result.
56float * column(int n, float result[], float mat[][], int index) {
57 for (int i=0; i<n; i++)
58 result[i] = mat[i][index];
59 return &result[0];
60}
61
62
63// checks multiplication is correct: entry i,j should
64// be the dot product of i-th row of A and j-th column of B...
65void testMult(int n) {
66 float actual[n][n];
67 float tmp[n];
68
69 matrixMultiply(n, actual, A0, B0);
70 for (int i=0; i<n; i++)
71 for (int j=0; j<n; j++)
72 $assert(dot(n, A0[i], column(n, tmp, B0, j)) == actual[i][j]);
73}
74
75
76
77/* Part 2 */
78
79// tests accociativity: (A0*B0)*C0 = A0*(B0*C0)
80void assoc(int n) {
81 float T1[n][n], T2[n][n], R1[n][n], R2[n][n];
82
83 matrixMultiply(n, T1, A0, B0);
84 matrixMultiply(n, R1, T1, C0);
85 matrixMultiply(n, T2, B0, C0);
86 matrixMultiply(n, R2, A0, T2);
87 $assert($equals(&R1, &R2));
88}
89
90
91/* Part 3 : Strassen */
92
93// adds two nxn matrices. C is "out" variable.
94void add(int n, float C[][], float A[][], float B[][]) {
95 for (int i = 0; i < n; i++)
96 for (int j = 0; j < n; j++)
97 C[i][j] = A[i][j] + B[i][j];
98}
99
100// subtracts two nxn matrices. C is "out" variable.
101void subtract(int n, float C[][], float A[][], float B[][]) {
102 for (int i = 0; i < n; i++)
103 for (int j = 0; j < n; j++)
104 C[i][j] = A[i][j] - B[i][j];
105}
106
107
108// Strassen algorithm from
109// https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/
110// I'm just going to assume n is a power of 2.
111// There is no problem dealing with the general case but need more
112// time!
113
114// multiplies two nxn matrices, storing result in C
115void strassenR(int n, float C[][], float A[][], float B[][]) {
116 if (n <= LEAF_SIZE) {
117 matrixMultiply(n, C, A, B);
118 } else {
119 // initializing the new sub-matrices
120 int newSize = n / 2;
121 float a11[newSize][newSize];
122 float a12[newSize][newSize];
123 float a21[newSize][newSize];
124 float a22[newSize][newSize];
125
126 float b11[newSize][newSize];
127 float b12[newSize][newSize];
128 float b21[newSize][newSize];
129 float b22[newSize][newSize];
130
131 float aResult[newSize][newSize];
132 float bResult[newSize][newSize];
133
134 // dividing the matrices in 4 sub-matrices:
135 for (int i = 0; i < newSize; i++) {
136 for (int j = 0; j < newSize; j++) {
137 a11[i][j] = A[i][j]; // top left
138 a12[i][j] = A[i][j + newSize]; // top right
139 a21[i][j] = A[i + newSize][j]; // bottom left
140 a22[i][j] = A[i + newSize][j + newSize]; // bottom right
141
142 b11[i][j] = B[i][j]; // top left
143 b12[i][j] = B[i][j + newSize]; // top right
144 b21[i][j] = B[i + newSize][j]; // bottom left
145 b22[i][j] = B[i + newSize][j + newSize]; // bottom right
146 }
147 }
148 // Calculating p1 to p7:
149 add(newSize, aResult, a11, a22);
150 add(newSize, bResult, b11, b22);
151 float p1[newSize][newSize];
152 strassenR(newSize, p1, aResult, bResult);
153 // p1 = (a11+a22) * (b11+b22)
154
155 add(newSize, aResult, a21, a22); // a21 + a22
156 float p2[newSize][newSize];
157 strassenR(newSize, p2, aResult, b11); // p2 = (a21+a22) * (b11)
158
159 subtract(newSize, bResult, b12, b22); // b12 - b22
160 float p3[newSize][newSize];
161 strassenR(newSize, p3, a11, bResult);
162 // p3 = (a11) * (b12 - b22)
163
164 subtract(newSize, bResult, b21, b11); // b21 - b11
165 float p4[newSize][newSize];
166 strassenR(newSize, p4, a22, bResult);
167 // p4 = (a22) * (b21 - b11)
168
169 add(newSize, aResult, a11, a12); // a11 + a12
170 float p5[newSize][newSize];
171 strassenR(newSize, p5, aResult, b22);
172 // p5 = (a11+a12) * (b22)
173
174 subtract(newSize, aResult, a21, a11); // a21 - a11
175 add(newSize, bResult, b11, b12); // b11 + b12
176 float p6[newSize][newSize];
177 strassenR(newSize, p6, aResult, bResult);
178 // p6 = (a21-a11) * (b11+b12)
179
180 subtract(newSize, aResult, a12, a22); // a12 - a22
181 add(newSize, bResult, b21, b22); // b21 + b22
182 float p7[newSize][newSize];
183 strassenR(newSize, p7, aResult, bResult);
184 // p7 = (a12-a22) * (b21+b22)
185
186 // calculating c21, c21, c11 e c22:
187 float c12[newSize][newSize];
188 add(newSize, c12, p3, p5); // c12 = p3 + p5
189 float c21[newSize][newSize];
190 add(newSize, c21, p2, p4); // c21 = p2 + p4
191
192 add(newSize, aResult, p1, p4); // p1 + p4
193 add(newSize, bResult, aResult, p7); // p1 + p4 + p7
194 float c11[newSize][newSize];
195 subtract(newSize, c11, bResult, p5);
196 // c11 = p1 + p4 - p5 + p7
197
198 add(newSize, aResult, p1, p3); // p1 + p3
199 add(newSize, bResult, aResult, p6); // p1 + p3 + p6
200 float c22[newSize][newSize];
201 subtract(newSize, c22, bResult, p2);
202 // c22 = p1 + p3 - p2 + p6
203
204 // Grouping the results obtained in a single matrix:
205 for (int i = 0; i < newSize; i++) {
206 for (int j = 0; j < newSize; j++) {
207 C[i][j] = c11[i][j];
208 C[i][j + newSize] = c12[i][j];
209 C[i + newSize][j] = c21[i][j];
210 C[i + newSize][j + newSize] = c22[i][j];
211 }
212 }
213 }
214}
215
216// test Strassen multiplication agrees with the regular one...
217void testStrassen(int n) {
218 float R1[n][n], R2[n][n];
219
220 matrixMultiply(n, R1, A0, B0);
221 strassenR(n, R2, A0, B0);
222 $assert($equals(&R1, &R2));
223}
224
225// determines whether n is a power of 2
226_Bool isPowerOf2(int n) {
227 while (n>1) {
228 if (n%2 != 0)
229 return $false;
230 n = n/2;
231 }
232 return $true;
233}
234
235/* main: runs the three tests */
236int main() {
237 $elaborate(N); // hint to verifier
238 printf("N=%d\n", N);
239 testMult(N);
240 assoc(N);
241 $assume(isPowerOf2(N));
242 testStrassen(N);
243}
Note: See TracBrowser for help on using the repository browser.