SimpleArrayReshaper.java
package dev.civl.mc.semantics.common;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import dev.civl.mc.semantics.IF.ArrayReshaper;
import dev.civl.mc.semantics.IF.ArrayToolBox.ArrayShape;
import dev.civl.sarl.IF.SymbolicUniverse;
import dev.civl.sarl.IF.expr.NumericExpression;
import dev.civl.sarl.IF.expr.NumericSymbolicConstant;
import dev.civl.sarl.IF.expr.SymbolicExpression;
import dev.civl.sarl.IF.number.IntegerNumber;
import dev.civl.sarl.IF.number.Number;
import dev.civl.sarl.IF.type.SymbolicArrayType;
import dev.civl.sarl.IF.type.SymbolicCompleteArrayType;
import dev.civl.sarl.IF.type.SymbolicType;
/**
* A simple {@link ArrayReshaper} implementation. No reasoning will be
* performed. State independent.
*
* @author ziqing
*/
public class SimpleArrayReshaper implements ArrayReshaper {
SymbolicUniverse universe;
SimpleArrayReshaper(SymbolicUniverse universe) {
this.universe = universe;
}
@Override
public SymbolicExpression arrayFlatten(SymbolicExpression array,
ArrayShape shape) {
int dims = shape.dimensions;
if (dims == 1)
return array;
Number extentNums[] = new Number[dims];
for (int d = 0; d < dims; d++) {
if ((extentNums[d] = universe
.extractNumber(shape.extents[d])) == null)
break;
if (d == dims - 1) // last iteration
return arrayFlattenConcrete(array, extentNums, shape);
}
return arrayFlattenLambda(array, extentNums, shape);
}
@Override
public SymbolicExpression arrayReshape(SymbolicExpression array,
ArrayShape originShape, ArrayShape targetShape) {
Number targetExtentNums[] = new Number[targetShape.dimensions];
Number originExtentNums[] = new Number[originShape.dimensions];
// if two shapes are the same:
if (targetShape.dimensions == originShape.dimensions) {
boolean isSame = true;
for (int i = 0; i < targetShape.dimensions; i++)
if (!originShape.extents[i].equals(targetShape.extents[i])) {
isSame = false;
break;
}
if (isSame)
return array;
}
// if both shape have concrete shape:
for (int i = 0; i < targetShape.dimensions; i++) {
targetExtentNums[i] = universe
.extractNumber(targetShape.extents[i]);
if (targetExtentNums[i] == null)
return arrayReshapeLambda(array, originShape, targetShape);
}
for (int i = 0; i < originShape.dimensions; i++) {
originExtentNums[i] = universe
.extractNumber(originShape.extents[i]);
if (originExtentNums[i] == null)
return arrayReshapeLambda(array, originShape, targetShape);
}
return arrayReshapeConcrete(array, originShape, originExtentNums,
targetShape, targetExtentNums);
}
@Override
public boolean allComplete(SymbolicArrayType arrayType) {
return allCompleteWorker(arrayType);
}
/* ***************************** private methods *********************/
static private boolean allCompleteWorker(SymbolicArrayType arrayType) {
int dims = arrayType.dimensions();
for (int i = 0; i < dims - 1; i++) {
if (!arrayType.isComplete())
return false;
arrayType = (SymbolicArrayType) arrayType.elementType();
}
return true;
}
private SymbolicExpression arrayFlattenConcrete(SymbolicExpression array,
Number[] extentNums, ArrayShape shape) {
int extents[] = new int[extentNums.length];
for (int i = 0; i < extents.length; i++)
extents[i] = ((IntegerNumber) extentNums[i]).intValue();
List<SymbolicExpression> elements = arrayFlattenConcreteWorker(array,
extents, extents.length);
return universe.array(shape.baseType, elements);
}
private List<SymbolicExpression> arrayFlattenConcreteWorker(
SymbolicExpression array, int extents[], int dim) {
List<SymbolicExpression> elements = new LinkedList<>();
int extent = extents[extents.length - dim - 1];
if (dim == 0)
for (int i = 0; i < extent; i++)
elements.add(universe.arrayRead(array, universe.integer(i)));
else
for (int i = 0; i < extent; i++)
elements.addAll(arrayFlattenConcreteWorker(
universe.arrayRead(array, universe.integer(i)), extents,
dim - 1));
return elements;
}
private SymbolicExpression arrayFlattenLambda(SymbolicExpression array,
Number extentNums[], ArrayShape shape) {
int dims = shape.dimensions;
int newDims, i;
NumericExpression[] arraySliceSizes = shape.subArraySizes;
NumericExpression[] arrayExtents = shape.extents;
// pre-optimize 1: transform an array that has such form
// a[1][1][...][1][n][...][m] to a'[n][...][m]:
for (i = 0; i < dims - 1; i++)
if (shape.extents[i].isOne())
array = universe.arrayRead(array, universe.zeroInt());
else
break;
newDims = dims - i;
if (newDims == 1)
return array;
if (newDims < dims) {
arraySliceSizes = Arrays.copyOfRange(arraySliceSizes,
dims - newDims, dims);
extentNums = Arrays.copyOfRange(extentNums, dims - newDims, dims);
arrayExtents = Arrays.copyOfRange(shape.extents, dims - newDims,
dims);
dims = newDims;
}
// end of pre-optimize
return arrayFlattenLambdaWorker(array, arraySliceSizes, arrayExtents);
}
private SymbolicExpression arrayFlattenLambdaWorker(
SymbolicExpression array, NumericExpression[] arraySliceSizes,
NumericExpression[] arrayExtents) {
int dims = arrayExtents.length;
NumericSymbolicConstant symConst = (NumericSymbolicConstant) universe
.symbolicConstant(universe.stringObject("i"),
universe.integerType());
NumericExpression index = symConst;
SymbolicExpression arrayReadFunc = array;
for (int d = 0; d < dims; d++) {
arrayReadFunc = universe.arrayRead(arrayReadFunc,
universe.divide(index, arraySliceSizes[d]));
index = universe.modulo(index, arraySliceSizes[d]);
}
SymbolicCompleteArrayType arrayType = universe.arrayType(
arrayReadFunc.type(),
universe.multiply(arraySliceSizes[0], arrayExtents[0]));
return universe.arrayLambda(arrayType,
universe.lambda(symConst, arrayReadFunc));
}
private SymbolicExpression arrayReshapeConcrete(SymbolicExpression array,
ArrayShape originShape, Number[] originExtentNums,
ArrayShape targetShape, Number[] targetExtentNums) {
int originExtents[] = new int[originExtentNums.length];
int targetExtents[] = new int[targetExtentNums.length];
for (int i = 0; i < originExtents.length; i++)
originExtents[i] = ((IntegerNumber) originExtentNums[i]).intValue();
for (int i = 0; i < targetExtents.length; i++)
targetExtents[i] = ((IntegerNumber) targetExtentNums[i]).intValue();
// read all elements from origin
LinkedList<SymbolicExpression> queue = new LinkedList<>();
int queueSize, dimIter = 1;
for (int i = 0; i < originExtents[0]; i++)
queue.add(universe.arrayRead(array, universe.integer(i)));
queueSize = queue.size();
while (dimIter < originExtents.length) {
for (int i = 0; i < queueSize; i++) {
SymbolicExpression sub = queue.removeFirst();
for (int j = 0; j < originExtents[dimIter]; j++)
queue.add(universe.arrayRead(sub, universe.integer(j)));
}
dimIter++;
queueSize = queue.size();
}
// build new array using the list of total elements:
SymbolicType arrayType = targetShape.baseType;
for (int d = 0; d < targetShape.dimensions; d++) {
int extent = targetExtents[targetShape.dimensions - 1 - d];
int numSubArrays = queue.size() / extent;
for (int j = 0; j < numSubArrays; j++) {
List<SymbolicExpression> subArraysElements = new LinkedList<>();
for (int i = 0; i < extent; i++)
subArraysElements.add(queue.removeFirst());
queue.add(universe.array(arrayType, subArraysElements));
}
arrayType = universe.arrayType(arrayType, universe.integer(extent));
}
return queue.removeFirst();
}
private SymbolicExpression arrayReshapeLambda(SymbolicExpression array,
ArrayShape originShape, ArrayShape targetShape) {
NumericSymbolicConstant lambdaConstants[] = new NumericSymbolicConstant[targetShape.dimensions];
for (int i = 0; i < lambdaConstants.length; i++)
lambdaConstants[i] = (NumericSymbolicConstant) universe
.symbolicConstant(universe.stringObject("i_" + i),
universe.integerType());
NumericExpression indices4origin[] = arrayIndiceProjecting(
lambdaConstants, targetShape, originShape);
// build array lambda:
SymbolicType arrayType = targetShape.baseType;
SymbolicExpression lambdaFunction = array;
for (int i = 0; i < originShape.dimensions; i++)
lambdaFunction = universe.arrayRead(lambdaFunction,
indices4origin[i]);
for (int i = targetShape.dimensions - 1; i >= 0; i--) {
arrayType = universe.arrayType(arrayType, targetShape.extents[i]);
lambdaFunction = universe.lambda(lambdaConstants[i],
lambdaFunction);
lambdaFunction = universe.arrayLambda(
(SymbolicCompleteArrayType) arrayType, lambdaFunction);
}
return lambdaFunction;
}
@Override
public NumericExpression[] sliceIndiceProjecting(
NumericExpression[] fromSliceArrayIndices,
ArrayShape fromSliceArrayShape,
NumericExpression[] fromSliceStartIndices,
ArrayShape toSliceArrayShape,
NumericExpression[] toSliceStartIndices) {
// assert fromSliceArrayShape.baseType == toSliceArrayShape.baseType
assert fromSliceStartIndices.length <= fromSliceArrayShape.dimensions;
assert toSliceStartIndices.length <= toSliceArrayShape.dimensions;
NumericExpression pos = universe.zeroInt();
NumericExpression fromSliceStartOffsets = pos,
toSliceStartOffsets = pos;
for (int i = 0; i < fromSliceArrayIndices.length; i++)
pos = universe.add(pos, universe.multiply(fromSliceArrayIndices[i],
fromSliceArrayShape.subArraySizes[i]));
for (int i = 0; i < fromSliceStartIndices.length; i++)
fromSliceStartOffsets = universe.add(fromSliceStartOffsets,
universe.multiply(fromSliceStartIndices[i],
fromSliceArrayShape.subArraySizes[i]));
pos = universe.subtract(pos, fromSliceStartOffsets);
for (int i = 0; i < toSliceStartIndices.length; i++)
toSliceStartOffsets = universe.add(toSliceStartOffsets,
universe.multiply(toSliceStartIndices[i],
toSliceArrayShape.subArraySizes[i]));
pos = universe.add(pos, toSliceStartOffsets);
NumericExpression projectedIndices[] = new NumericExpression[toSliceArrayShape.dimensions];
for (int i = 0; i < projectedIndices.length; i++) {
projectedIndices[i] = universe.divide(pos,
toSliceArrayShape.subArraySizes[i]);
pos = universe.modulo(pos, toSliceArrayShape.subArraySizes[i]);
}
return projectedIndices;
}
@Override
public NumericExpression[] arrayIndiceProjecting(
NumericExpression[] fromArrayIndices, ArrayShape fromArrayShape,
ArrayShape toArrayShape) {
NumericExpression zero[] = {universe.zeroInt()};
return this.sliceIndiceProjecting(fromArrayIndices, fromArrayShape,
zero, toArrayShape, zero);
}
/* ********* Testing *********** */
// public static void main(String args[]) {
// SymbolicUniverse universe = SARL.newIdealUniverse();
// NumericSymbolicConstant n, m;
// SymbolicConstant arr4d, brr4d, brr3d;
// SymbolicArrayType arrayType1d, arrayType2d, arrayType3d, arrayType4d;
// SymbolicArrayType brrayType1d, brrayType2d, brrayType3d, brrayType4d;
// BooleanExpression context;
//
// n = (NumericSymbolicConstant) universe.symbolicConstant(
// universe.stringObject("n"), universe.integerType());
// m = (NumericSymbolicConstant) universe.symbolicConstant(
// universe.stringObject("m"), universe.integerType());
//
// context = universe.and(universe.lessThan(universe.zeroInt(), n),
// universe.lessThan(universe.zeroInt(), m));
//
// arrayType1d = universe.arrayType(universe.integerType(), n);
// arrayType2d = universe.arrayType(arrayType1d, m);
// arrayType3d = universe.arrayType(arrayType2d, universe.integer(4));
// arrayType4d = universe.arrayType(arrayType3d, universe.integer(3));
// arr4d = universe.symbolicConstant(universe.stringObject("arr4d"),
// arrayType4d);
//
// brrayType1d = universe.arrayType(universe.integerType(), m);
// brrayType2d = universe.arrayType(brrayType1d, n);
// brrayType3d = universe.arrayType(brrayType2d, universe.integer(3));
// brrayType4d = universe.arrayType(brrayType3d, universe.integer(4));
// brr4d = universe.symbolicConstant(universe.stringObject("brr4d"),
// brrayType4d);
// brr3d = universe.symbolicConstant(universe.stringObject("brr3d"),
// universe.arrayType(brrayType2d, universe.integer(12)));
//
// ArrayReshaper reshaper = new SimpleArrayReshaper(universe);
// SymbolicExpression flat_arr = reshaper.arrayFlatten(arr4d);
// SymbolicExpression arr4d_to_brr4d, flat_arr4d_to_brr4d, arr4d_to_brr3d,
// brr3d_to_arr4d;
//
// System.out.println(flat_arr);
//
// arr4d_to_brr4d = reshaper.arrayReshape(arr4d,
// (SymbolicArrayType) brr4d.type());
//
// assert reshaper.isPhysicallyEquivalent((SymbolicArrayType) arr4d.type(),
// (SymbolicArrayType) brr3d.type()).isTrue();
// assert reshaper.allComplete((SymbolicArrayType) arr4d.type());
//
// arr4d_to_brr3d = reshaper.arrayReshape(arr4d,
// (SymbolicArrayType) brr3d.type());
// flat_arr4d_to_brr4d = reshaper.arrayReshape(flat_arr,
// (SymbolicArrayType) brr4d.type());
//
// ValidityResult validity = universe.reasoner(universe.trueExpression())
// .validWhy3(
// universe.equals(arr4d_to_brr4d, flat_arr4d_to_brr4d));
//
// assert validity.getResultType() == ResultType.YES;
//
// // arr4d_to_brr3d [0][0][0] == arr4d[0][0][0][0]
// validity = universe
// .reasoner(
// universe.trueExpression())
// .validWhy3(
// universe.equals(universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// arr4d_to_brr3d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()), universe
// .arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// arr4d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt())));
// assert validity.getResultType() == ResultType.YES;
//
// // arr4d_to_brr3d [11][n-1][m-1] == arr4d[3][2][m-1][n-1]
// validity = universe
// .reasoner(
// context)
// .validWhy3(universe.equals(universe.arrayRead(
// universe.arrayRead(universe.arrayRead(arr4d_to_brr3d,
// universe.integer(11)),
// universe.subtract(n, universe.oneInt())),
// universe.subtract(m,
// universe.oneInt())),
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(arr4d,
// universe.integer(2)),
// universe.integer(3)),
// universe.subtract(m,
// universe.oneInt())),
// universe.subtract(n, universe.oneInt()))));
//
// assert validity.getResultType() == ResultType.YES;
//
// brr3d_to_arr4d = reshaper.arrayReshape(brr3d,
// (SymbolicArrayType) arr4d.type());
//
// // brr3d_to_arr4d[0][0][0][0] == brr3d[0][0][0]
// validity = universe.reasoner(universe.trueExpression())
// .validWhy3(
// universe.equals(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// brr3d_to_arr4d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(brr3d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt())));
// assert validity.getResultType() == ResultType.YES;
// // brr3d_to_arr4d[2][3][m-1][n-1] == brr3d[11][n-1][m-1]
// validity = universe
// .reasoner(
// context)
// .validWhy3(
// universe.equals(
// universe.arrayRead(universe
// .arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// brr3d_to_arr4d,
// universe.integer(
// 2)),
// universe.integer(3)),
// universe.subtract(m,
// universe.oneInt())),
// universe.subtract(n,
// universe.oneInt())),
// universe.arrayRead(
// universe.arrayRead(universe.arrayRead(
// brr3d, universe.integer(11)),
// universe.subtract(n,
// universe.oneInt())),
// universe.subtract(m,
// universe.oneInt()))));
//
// assert validity.getResultType() == ResultType.YES;
// // ******** nested reshaping test: **********
// // *** multi-d to 1-d
// SymbolicExpression brr3d_to_arr4d_to_1d = reshaper
// .arrayFlatten(brr3d_to_arr4d);
// SymbolicExpression brr3d_to_1d = reshaper.arrayFlatten(brr3d);
//
// // brr3d_to_4d[0][0][0][0] == brr3d_to_arr4d_to_1d[0]
// validity = universe.reasoner(context)
// .validWhy3(
// universe.equals(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// brr3d_to_arr4d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.arrayRead(brr3d_to_arr4d_to_1d,
// universe.zeroInt())));
//
// assert validity.getResultType() == ResultType.YES;
//
// // brr3d_to_4d[0][0][0][1] == brr3d_to_arr4d_to_1d[1]
// validity = universe.reasoner(context)
// .valid(universe
// .equals(universe
// .arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// brr3d_to_arr4d,
// universe.zeroInt()),
// universe.zeroInt()),
// universe.zeroInt()),
// universe.oneInt()),
// universe.arrayRead(brr3d_to_arr4d_to_1d,
// universe.oneInt())));
//
// assert validity.getResultType() != ResultType.NO;
//
// // brr3d_to_4d[2][3][m-1][n-1] == brr3d_to_arr4d_to_1d[12*m*n-1]
// validity = universe
// .reasoner(
// context)
// .valid(universe
// .equals(universe
// .arrayRead(universe
// .arrayRead(
// universe.arrayRead(
// universe.arrayRead(
// brr3d_to_arr4d,
// universe.integer(
// 2)),
// universe.integer(3)),
// universe.subtract(m,
// universe.oneInt())),
// universe.subtract(n,
// universe.oneInt())),
// universe.arrayRead(
// brr3d_to_arr4d_to_1d, universe
// .subtract(
// universe.multiply(
// Arrays.asList(
// universe.integer(
// 12),
// m, n)),
// universe.oneInt()))));
//
// assert validity.getResultType() != ResultType.NO;
//
// validity = universe.reasoner(context)
// .valid(universe.equals(brr3d_to_arr4d_to_1d, brr3d_to_1d));
// System.out.println(universe.equals(brr3d_to_arr4d_to_1d, brr3d_to_1d));
// assert validity.getResultType() != ResultType.NO;
// }
}