SimpleArrayCutter.java

package dev.civl.mc.semantics.common;

import java.util.ArrayList;
import java.util.List;

import dev.civl.mc.semantics.IF.ArrayCutter;
import dev.civl.mc.semantics.IF.ArrayReshaper;
import dev.civl.mc.semantics.IF.ArrayToolBox.ArrayShape;
import dev.civl.sarl.IF.SymbolicUniverse;
import dev.civl.sarl.IF.expr.BooleanExpression;
import dev.civl.sarl.IF.expr.NumericExpression;
import dev.civl.sarl.IF.expr.NumericSymbolicConstant;
import dev.civl.sarl.IF.expr.SymbolicExpression;
import dev.civl.sarl.IF.number.IntegerNumber;
import dev.civl.sarl.IF.number.Number;
import dev.civl.sarl.IF.type.SymbolicType;

/**
 * A simple implementation of {@link ArrayCutter} where no prover is called.
 * 
 * @author ziqing
 *
 */
public class SimpleArrayCutter implements ArrayCutter {
	/**
	 * a reference to {@link SymbolicUniverse}.
	 */
	private SymbolicUniverse universe;

	/**
	 * a reference to {@link ArrayReshaper}
	 */
	private ArrayReshaper reshaper;

	SimpleArrayCutter(SymbolicUniverse universe, ArrayReshaper reshaper) {
		this.universe = universe;
		this.reshaper = reshaper;
	}

	@Override
	public SymbolicExpression arraySlice(SymbolicExpression array,
			ArrayShape shape, NumericExpression[] indices,
			NumericExpression count) {
		assert shape.dimensions == indices.length;
		Number countConcreteValue = universe.extractNumber(count);
		NumericExpression zero = universe.zeroInt();
		boolean isWholeArray = true;

		// special case 1: if the slice is the whole array
		for (NumericExpression idx : indices)
			if (idx != zero) {
				isWholeArray = false;
				break;
			}
		if (isWholeArray && shape.arraySize == count)
			return array;
		// special case 2: if count is concrete
		if (countConcreteValue != null)
			return arraySlice(array, shape, indices,
					((IntegerNumber) countConcreteValue).intValue());
		// general case:
		return arraySlice_general(array, shape, indices, count);
	}

	@Override
	public SymbolicExpression arraySlice(SymbolicExpression array,
			ArrayShape shape, NumericExpression[] indices, int count) {
		List<SymbolicExpression> elements = arraySlice_concreteCount(array,
				shape, indices, count);

		return universe.array(shape.baseType, elements);
	}

	@Override
	public SymbolicExpression arraySliceWrite(SymbolicExpression sliceArray,
			ArrayShape sliceArrayShape, NumericExpression[] sliceStartIndices,
			NumericExpression count, SymbolicExpression targetArray,
			ArrayShape targetShape, NumericExpression[] targetStartIndices) {
		assert sliceArrayShape.dimensions == sliceStartIndices.length;
		assert targetShape.dimensions == targetStartIndices.length;
		// baseTypes of sliceArrayShape and targetArray must be equivalent, but
		// the check requires reasoner, so it will not be checked as an
		// assertion here ...

		Number countConcreteValue = universe.extractNumber(count);

		if (countConcreteValue != null) {
			int countIntValue = ((IntegerNumber) countConcreteValue).intValue();
			int targetDims = targetShape.dimensions;
			ArrayList<SymbolicExpression> elements = arraySlice_concreteCount(
					sliceArray, sliceArrayShape, sliceStartIndices,
					countIntValue);
			NumericExpression writtenPos = universe.zeroInt();
			NumericExpression[] writtenIndices = new NumericExpression[targetDims];

			// general formula for computing written indices ...
			for (int i = 0; i < targetDims; i++)
				writtenPos = universe.add(writtenPos, universe.multiply(
						targetStartIndices[i], targetShape.subArraySizes[i]));
			for (int i = 0; i < countIntValue; i++) {
				NumericExpression currPos = universe.add(writtenPos,
						universe.integer(i));
				// TODO: optimize this:
				// write elements one by one ...
				for (int j = 0; j < targetDims; j++) {
					writtenIndices[j] = universe.divide(currPos,
							targetShape.subArraySizes[j]);
					currPos = universe.modulo(currPos,
							targetShape.subArraySizes[j]);
				}
				targetArray = mdArrayWrite(targetArray, writtenIndices,
						elements.get(i));
			}
			return targetArray;
		} else
			return sliceReadThenWrite_general(sliceArray, sliceArrayShape,
					sliceStartIndices, count, targetArray, targetShape,
					targetStartIndices);
	}

	@Override
	public SymbolicExpression mdArrayRead(SymbolicExpression array,
			NumericExpression indices[]) {
		SymbolicExpression result = array;

		for (int i = 0; i < indices.length; i++)
			result = universe.arrayRead(result, indices[i]);
		return result;
	}

	@Override
	public SymbolicExpression mdArrayWrite(SymbolicExpression array,
			NumericExpression indices[], SymbolicExpression value) {
		return mdArrayWriteWorker(array, indices, 0, value);
	}

	/* ************************** private methods **********************/

	private SymbolicExpression mdArrayWriteWorker(SymbolicExpression array,
			NumericExpression indices[], int indicesHead,
			SymbolicExpression value) {
		if (indices.length == indicesHead)
			return array;
		if (indices.length - indicesHead == 1)
			return universe.arrayWrite(array, indices[indicesHead], value);
		else
			return universe.arrayWrite(array, indices[indicesHead],
					mdArrayWriteWorker(
							universe.arrayRead(array, indices[indicesHead]),
							indices, indicesHead + 1, value));
	}
	/**
	 * <p>
	 * The general solution for "array slice read-then-write" operation. Reading
	 * an array slice from a "data" array <code>d</code>, starting from indices
	 * <code>d_I</code>, then writes the slice into a target array
	 * <code>t</code>, starting from indices <code>t_I</code>.
	 * </p>
	 * 
	 * <p>
	 * This solution uses array lambda and conditional expression to represent
	 * such operation.
	 * </p>
	 * 
	 * @param dataArray
	 *            the array where a slice is carved out
	 * @param dataShape
	 *            the shape of the dataArray
	 * @param dataStartIndices
	 *            the starting indices of the slice.
	 *            <code>dataStartIndices.length == dataShape.dimensions</code>
	 * @param count
	 *            the number of elements in the slice
	 * @param targetArray
	 *            the array where the slice will be written into
	 * @param targetShape
	 *            the shape of the targetArray
	 * @param targetStartIndices
	 *            the starting indices of where the slice being written.
	 *            <code>targetStartIndices.length == targetShape.dimensions</code>
	 * @return the updated targetArray
	 */
	private SymbolicExpression sliceReadThenWrite_general(
			SymbolicExpression dataArray, ArrayShape dataShape,
			NumericExpression[] dataStartIndices, NumericExpression count,
			SymbolicExpression targetArray, ArrayShape targetShape,
			NumericExpression[] targetStartIndices) {
		int targetDims = targetShape.dimensions;
		NumericSymbolicConstant symConsts[] = new NumericSymbolicConstant[targetDims];

		for (int i = 0; i < symConsts.length; i++)
			symConsts[i] = (NumericSymbolicConstant) universe.symbolicConstant(
					universe.stringObject("i" + i), universe.integerType());

		// compute indices that can access data array as if it has the target
		// shape
		NumericExpression projectedConstsToDataArray[] = reshaper
				.sliceIndiceProjecting(symConsts, targetShape,
						targetStartIndices, dataShape, dataStartIndices);
		SymbolicExpression dataSliceLambda = dataArray;
		int dataDims = dataShape.dimensions;

		// read slice in data array as if it has the target shape
		for (int i = 0; i < dataDims; i++)
			dataSliceLambda = universe.arrayRead(dataSliceLambda,
					projectedConstsToDataArray[i]);

		// the range of the slice in the target array:
		NumericExpression lower = universe.zeroInt();
		// upper = lower + count (exclusive)

		for (int i = 0; i < targetDims; i++)
			if (targetStartIndices[i].isZero())
				continue;
			else
				lower = universe.add(lower, universe.multiply(
						targetStartIndices[i], targetShape.subArraySizes[i]));

		// lower <= position < upper ? dataSliceLambda : targetArrayLambda
		NumericExpression pos = universe.zeroInt();
		BooleanExpression cond;
		SymbolicExpression targetArrayLambda, resultLambda;

		for (int i = 0; i < targetDims; i++)
			pos = universe.add(pos, universe.multiply(symConsts[i],
					targetShape.subArraySizes[i]));
		cond = universe.and(universe.lessThanEquals(lower, pos),
				universe.lessThan(pos, universe.add(lower, count)));
		targetArrayLambda = mdArrayRead(targetArray, symConsts);
		resultLambda = universe.cond(cond, dataSliceLambda, targetArrayLambda);

		// build array lambda then return ...
		SymbolicType elementType = resultLambda.type();
		for (int i = targetDims - 1; i >= 0; i--) {
			resultLambda = universe.lambda(symConsts[i], resultLambda);
			resultLambda = universe.arrayLambda(
					universe.arrayType(elementType, targetShape.extents[i]),
					resultLambda);
			elementType = resultLambda.type();
		}
		return resultLambda;
	}

	/**
	 * Read a slice from an array from a starting indices
	 * 
	 * @param array
	 *            the array where a slice will be carved out
	 * @param shape
	 *            the shape of the array
	 * @param indices
	 *            the starting indices of the slice
	 *            <code>indices.length == shape.dimensions</code>
	 * @param count
	 *            the number of elements in the slice.
	 *            <code>element.type() == shape.baseType</code>
	 * @return a slice which is represented as an {@link ArrayList} of elements
	 */
	private ArrayList<SymbolicExpression> arraySlice_concreteCount(
			SymbolicExpression array, ArrayShape shape,
			NumericExpression[] indices, int count) {
		ArrayList<SymbolicExpression> elements = new ArrayList<>();
		NumericExpression pos = universe.zeroInt();

		for (int j = 0; j < indices.length; j++)
			pos = universe.add(pos,
					universe.multiply(indices[j], shape.subArraySizes[j]));

		for (int i = 0; i < count; i++) {
			NumericExpression remain = universe.add(pos, universe.integer(i));

			for (int j = 0; j < indices.length; j++) {
				indices[j] = universe.divide(remain, shape.subArraySizes[j]);
				remain = universe.modulo(remain, shape.subArraySizes[j]);
			}
			elements.add(mdArrayRead(array, indices));
		}
		return elements;
	}

	/**
	 * Read a slice from an array from a starting indices
	 * 
	 * @param array
	 *            the array where a slice will be carved out
	 * @param shape
	 *            the shape of the array
	 * @param indices
	 *            the starting indices of the slice
	 *            <code>indices.length == shape.dimensions</code>
	 * @param count
	 *            the number of elements in the slice.
	 *            <code>element.type() == shape.baseType</code>
	 * @return a slice which is represented as an {@link ArrayList} of elements
	 */
	private SymbolicExpression arraySlice_general(SymbolicExpression array,
			ArrayShape arrayShape, NumericExpression[] indices,
			NumericExpression count) {
		// a 1d array lambda ...
		NumericSymbolicConstant symConst = (NumericSymbolicConstant) universe
				.symbolicConstant(universe.stringObject("i"),
						universe.integerType());
		NumericExpression pos = universe.zeroInt();

		for (int i = 0; i < indices.length; i++)
			pos = universe.add(pos,
					universe.multiply(indices[i], arrayShape.subArraySizes[i]));

		SymbolicExpression lambdaFunc = array;

		pos = universe.add(pos, symConst);
		for (int i = 0; i < indices.length; i++) {
			lambdaFunc = universe.arrayRead(lambdaFunc,
					universe.divide(pos, arrayShape.subArraySizes[i]));
			pos = universe.modulo(pos, arrayShape.subArraySizes[i]);
		}
		return universe.arrayLambda(
				universe.arrayType(arrayShape.baseType, count),
				universe.lambda(symConst, lambdaFunc));
	}
}