// ClauseTransformGuideGenerator.java

package dev.civl.mc.transform.common.contracts;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import dev.civl.abc.ast.IF.ASTFactory;
import dev.civl.abc.ast.node.IF.IdentifierNode;
import dev.civl.abc.ast.node.IF.NodeFactory;
import dev.civl.abc.ast.node.IF.acsl.MPIContractExpressionNode;
import dev.civl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import dev.civl.abc.ast.node.IF.expression.ConstantNode;
import dev.civl.abc.ast.node.IF.expression.EnumerationConstantNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode.ExpressionKind;
import dev.civl.abc.ast.node.IF.expression.FunctionCallNode;
import dev.civl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode.Operator;
import dev.civl.abc.ast.node.IF.expression.RegularRangeNode;
import dev.civl.abc.ast.node.IF.expression.RemoteOnExpressionNode;
import dev.civl.abc.ast.node.IF.statement.BlockItemNode;
import dev.civl.abc.ast.node.IF.statement.StatementNode;
import dev.civl.abc.ast.node.IF.type.ArrayTypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode.TypeNodeKind;
import dev.civl.abc.ast.type.IF.PointerType;
import dev.civl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import dev.civl.abc.ast.type.IF.Type;
import dev.civl.abc.ast.value.IF.Value;
import dev.civl.abc.ast.value.IF.ValueFactory.Answer;
import dev.civl.abc.token.IF.Source;
import dev.civl.abc.token.IF.SyntaxException;
import dev.civl.abc.token.IF.TokenFactory;
import dev.civl.abc.util.IF.Pair;
import dev.civl.mc.model.IF.CIVLSyntaxException;
import dev.civl.mc.transform.SubstituteGuide;
import dev.civl.mc.transform.common.BaseWorker;
import dev.civl.mc.transform.common.contracts.FunctionContractBlock.ContractClause;
import dev.civl.mc.util.IF.Triple;

/**
 * <p>
 * This class generates {@link ClauseTransformGuide} for contract clauses such
 * as <code>requires</code> and <code>ensures</code> clauses. An instance of
 * {@link ClauseTransformGuide} corresponds to one contract clause.
 * </p>
 * <p>
 * This class contains only static methods, hence no runtime instance of this
 * class is needed.
 * </p>
 * 
 * @author ziqingluo
 *
 */
public class ClauseTransformGuideGenerator {

	/**
	 * A clause transformation guide mainly consists of the following pieces of
	 * information:
	 * <ul>
	 * 
	 * <li>clause: the contract clause which specifies boolean expressions</li>
	 * <li>conditions (a.k.a. assumptions): the assumptions over the clause:
	 * <code>conditions IMPLIES expressions</code></li>
	 * <li>arrivends: a set of expressions representing processes that this
	 * clause depends on</li>
	 * <li>side conditions: a set of boolean expressions that must be proved
	 * (implied) by the contract</li>
	 * <li>prefix: a set of intermediate variables and statements that must come
	 * BEFORE the evaluation of the clause expression.</li>
	 * <li>suffix: a set of intermediate variables and statements that must come
	 * AFTER the evaluation of the clause expression.</li>
	 * <li>substitutions: a set of substitutions that will modify the clause
	 * expressions</li>
	 * </ul>
	 * 
	 * @author ziqingluo
	 *
	 */
	static class ClauseTransformGuide {
		/** the contract clause that this guide is generated for */
		final ContractClause clause;
		/** the conditions (assumptions) over the clause */
		List<ExpressionNode> conditions;
		/** expressions representing processes that the clause depends on */
		List<ExpressionNode> arrivends;
		/** boolean expressions that must be proved (implied) by the contract */
		List<ExpressionNode> sideConditions;
		/** block items that must come BEFORE the clause is evaluated */
		List<BlockItemNode> prefix;
		/** substitutions, keyed by the expression that gets replaced */
		Map<ExpressionNode, SubstituteGuide> substitutions;
		/** block items that must come AFTER the clause is evaluated */
		List<BlockItemNode> suffix;
		/**
		 * a cache that helps reduce the number of intermediate variables for
		 * handling MPI_datatypes and MPI_extents
		 */
		Map<String, String> mpiDatatypeToIntermediateName;
		/**
		 * a counter for assigning unique names for intermediate variables
		 */
		int nameCounter;

		/**
		 * Creates a guide for the given clause. The given lists and the given
		 * cache are stored by reference (no copies are made); side conditions,
		 * prefix, suffix and substitutions start out empty.
		 * 
		 * @param clause
		 *            the contract clause this guide belongs to
		 * @param conditions
		 *            the conditions (assumptions) over the clause
		 * @param arrivends
		 *            the expressions representing the processes the clause
		 *            depends on
		 * @param mpiDatatype2IntermediateName
		 *            the cache from datatype names to the names of their
		 *            intermediate variables
		 * @param nameCounter
		 *            the initial value of the counter used for naming
		 *            intermediate variables
		 */
		ClauseTransformGuide(ContractClause clause,
				List<ExpressionNode> conditions, List<ExpressionNode> arrivends,
				Map<String, String> mpiDatatype2IntermediateName,
				int nameCounter) {
			// state handed in by the caller
			this.clause = clause;
			this.conditions = conditions;
			this.arrivends = arrivends;
			this.mpiDatatypeToIntermediateName = mpiDatatype2IntermediateName;
			this.nameCounter = nameCounter;
			// state accumulated by the transformation
			this.sideConditions = new LinkedList<>();
			this.prefix = new LinkedList<>();
			this.substitutions = new HashMap<>();
			this.suffix = new LinkedList<>();
		}
	}

	/**
	 * Records in <code>out</code> everything needed to transform a clause
	 * that is going to be assumed: remote expressions, <code>\old</code>
	 * expressions, <code>\result</code>, MPI extents and datatypes, and
	 * finally the ACSL and MPI valid expressions. Nothing is done when the
	 * clause has no special references.
	 * 
	 * @param clause
	 *            the clause being transformed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isLocal
	 *            true iff the clause belongs to a local contract block
	 * @param useRankAsPID
	 *            true iff the MPI rank constant shall denote the current
	 *            process in <code>\old</code> transformations
	 * @param civlcPreState
	 *            the expression referring to the pre-state, or null if there
	 *            is none
	 * @param out
	 *            the guide that receives the results
	 * @throws SyntaxException
	 *             propagated from the individual transformations
	 */
	public static void transformAssume(ContractClause clause,
			ASTFactory astFactory, boolean isLocal, boolean useRankAsPID,
			ExpressionNode civlcPreState, ClauseTransformGuide out)
			throws SyntaxException {
		if (clause.specialReferences == null)
			return;

		// transformations shared with assertions:
		transformRemoteOnExpression(clause, astFactory, isLocal, out);
		transformAcslOldExpression(clause, astFactory, civlcPreState, isLocal,
				useRankAsPID, out);
		transformAcslResult(clause, astFactory, out);
		transformMPIExtentAndDatatype(clause, astFactory, isLocal, out);
		// transformations that only apply to assumptions:
		transformAcslValid(clause, astFactory, true, out);
		transformMPIValid(clause, astFactory, true, out);
	}

	/**
	 * Records in <code>out</code> everything needed to transform a clause
	 * that is going to be asserted: remote expressions, <code>\old</code>
	 * expressions, <code>\result</code>, and MPI extents and datatypes. Valid
	 * expressions (ACSL and MPI) are deliberately left untouched. Nothing is
	 * done when the clause has no special references.
	 * 
	 * @param clause
	 *            the clause being transformed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isLocal
	 *            true iff the clause belongs to a local contract block
	 * @param useRankAsPID
	 *            true iff the MPI rank constant shall denote the current
	 *            process in <code>\old</code> transformations
	 * @param civlcPreState
	 *            the expression referring to the pre-state, or null if there
	 *            is none
	 * @param out
	 *            the guide that receives the results
	 * @throws SyntaxException
	 *             propagated from the individual transformations
	 */
	public static void transformAssert(ContractClause clause,
			ASTFactory astFactory, boolean isLocal, boolean useRankAsPID,
			ExpressionNode civlcPreState, ClauseTransformGuide out)
			throws SyntaxException {
		if (clause.specialReferences == null)
			return;

		transformRemoteOnExpression(clause, astFactory, isLocal, out);
		transformAcslOldExpression(clause, astFactory, civlcPreState, isLocal,
				useRankAsPID, out);
		transformAcslResult(clause, astFactory, out);
		transformMPIExtentAndDatatype(clause, astFactory, isLocal, out);
		// no need to transform valid (and MPI_valid)
	}

	/* *********** Methods transforming special expressions ********** */
	/**
	 * Registers, for every remote expression of the clause, a
	 * {@link ValueAtNodeSubstituteGuide} built from a state-null node, the
	 * process expression and the foreign expression of the remote expression.
	 * 
	 * @param clause
	 *            the clause whose remote expressions are processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isLocal
	 *            true iff the clause belongs to a local contract block
	 * @param out
	 *            the guide whose substitutions are extended
	 * @throws CIVLSyntaxException
	 *             if <code>isLocal</code> holds and the clause contains at
	 *             least one remote expression
	 */
	private static void transformRemoteOnExpression(ContractClause clause,
			ASTFactory astFactory, boolean isLocal, ClauseTransformGuide out) {
		NodeFactory factory = astFactory.getNodeFactory();

		for (ExpressionNode remote : clause.specialReferences.remoteExpressions) {
			// the check sits inside the loop: a local block without remote
			// expressions is fine
			if (isLocal)
				throw new CIVLSyntaxException(
						"Remote expressions are not allowed in local contract blocks",
						remote.getSource());

			RemoteOnExpressionNode remoteOn = (RemoteOnExpressionNode) remote;
			SubstituteGuide guide = new ValueAtNodeSubstituteGuide(
					factory.newStatenullNode(remote.getSource()),
					remoteOn.getProcessExpression(),
					remoteOn.getForeignExpressionNode(), remoteOn);

			out.substitutions.put(remoteOn, guide);
		}
	}

	/**
	 * Registers, for every <code>\old</code> expression of the clause, a
	 * {@link ValueAtNodeSubstituteGuide} built from the dereferenced pre-state
	 * expression, an expression for the current process and the argument of
	 * the <code>\old</code> expression.
	 * 
	 * @param clause
	 *            the clause whose <code>\old</code> expressions are processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param civlcPreState
	 *            the expression referring to the pre-state; it is copied, never
	 *            attached directly. May be null if the clause has no
	 *            <code>\old</code> expression.
	 * @param isLocal
	 *            not read by this method
	 * @param useRankAsPID
	 *            if true, the identifier
	 *            {@link MPIContractUtilities#MPI_COMM_RANK_CONST} denotes the
	 *            current process; otherwise the integer constant 0 is used
	 * @param out
	 *            the guide whose substitutions are extended
	 * @throws CIVLSyntaxException
	 *             if there is an <code>\old</code> expression but no pre-state
	 * @throws SyntaxException
	 *             declared by the node factory operations used here
	 */
	private static void transformAcslOldExpression(ContractClause clause,
			ASTFactory astFactory, ExpressionNode civlcPreState,
			boolean isLocal, boolean useRankAsPID, ClauseTransformGuide out)
			throws SyntaxException {
		NodeFactory factory = astFactory.getNodeFactory();

		for (ExpressionNode oldExpr : clause.specialReferences.acslOldExpressions) {
			// NOTE(review): the message speaks of post-conditions although it
			// is the pre-state that is missing here -- confirm that
			// "pre-condition" was not meant.
			if (civlcPreState == null)
				throw new CIVLSyntaxException(
						"\\old expressions are not allowed in post-condition",
						oldExpr.getSource());

			OperatorNode old = (OperatorNode) oldExpr;
			ExpressionNode proc;

			if (useRankAsPID)
				proc = identifierExpression(factory,
						MPIContractUtilities.MPI_COMM_RANK_CONST,
						old.getSource());
			else
				// TODO: need an expression represent current process:
				proc = factory.newIntegerConstantNode(old.getSource(), "0");

			SubstituteGuide guide = new ValueAtNodeSubstituteGuide(
					factory.newOperatorNode(civlcPreState.getSource(),
							Operator.DEREFERENCE, civlcPreState.copy()),
					proc, old.getArgument(0), old);

			out.substitutions.put(old, guide);
		}
	}

	/**
	 * Registers, for every <code>\result</code> expression of the clause, a
	 * substitution by an identifier expression named
	 * {@link MPIContractUtilities#ACSL_RESULT_VAR}.
	 * 
	 * @param clause
	 *            the clause whose <code>\result</code> expressions are
	 *            processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param out
	 *            the guide whose substitutions are extended
	 */
	private static void transformAcslResult(ContractClause clause,
			ASTFactory astFactory, ClauseTransformGuide out) {
		NodeFactory factory = astFactory.getNodeFactory();

		for (ExpressionNode result : clause.specialReferences.acslResults) {
			// each occurrence gets its own, freshly created identifier node
			SubstituteGuide guide = new CommonASTNodeSubstituteGuide(
					identifierExpression(factory,
							MPIContractUtilities.ACSL_RESULT_VAR,
							result.getSource()),
					result);

			out.substitutions.put(result, guide);
		}
	}

	/**
	 * Registers substitutions for the MPI extent expressions and the MPI
	 * datatype expressions of the clause. Either kind of expression is
	 * replaced by the identifier expression delivered by
	 * {@link #transformMPIExtentAndDatatypeWorker(NodeFactory, ExpressionNode, ClauseTransformGuide)}.
	 * 
	 * @param clause
	 *            the clause whose MPI extents and datatypes are processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isLocal
	 *            true iff the clause belongs to a local contract block
	 * @param out
	 *            the guide whose substitutions (and prefix) are extended
	 * @throws CIVLSyntaxException
	 *             if <code>isLocal</code> holds and the clause contains at
	 *             least one MPI extent expression
	 */
	private static void transformMPIExtentAndDatatype(ContractClause clause,
			ASTFactory astFactory, boolean isLocal, ClauseTransformGuide out) {
		NodeFactory factory = astFactory.getNodeFactory();

		for (ExpressionNode extentExpr : clause.specialReferences.mpiExtents) {
			if (isLocal)
				throw new CIVLSyntaxException(
						"MPI contract expressions are not allowed in local "
								+ "contract blocks",
						extentExpr.getSource());

			MPIContractExpressionNode mpiExtent = (MPIContractExpressionNode) extentExpr;
			// the sole argument of an extent expression is the datatype
			ExpressionNode replacement = transformMPIExtentAndDatatypeWorker(
					factory, mpiExtent.getArgument(0), out);

			out.substitutions.put(mpiExtent,
					new CommonASTNodeSubstituteGuide(replacement, mpiExtent));
		}
		// NOTE(review): datatypes are not rejected in local contract blocks,
		// unlike extents above -- confirm that this is intended.
		for (ExpressionNode datatype : clause.specialReferences.mpiDatatypes) {
			ExpressionNode replacement = transformMPIExtentAndDatatypeWorker(
					factory, datatype, out);

			out.substitutions.put(datatype,
					new CommonASTNodeSubstituteGuide(replacement, datatype));
		}
	}

	/**
	 * Returns an identifier expression referring to the intermediate variable
	 * that holds the size of the given datatype. If no such variable is cached
	 * for the datatype, a new <code>size_t</code> variable, initialized with a
	 * {@link MPIContractUtilities#MPI_SIZEOF_DATATYPE} call, is declared in the
	 * prefix of <code>out</code> and (when the datatype has a name) cached.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param datatype
	 *            the expression denoting the MPI datatype
	 * @param out
	 *            the guide providing the cache, the name counter and the prefix
	 * @return a fresh identifier expression naming the intermediate variable
	 */
	private static ExpressionNode transformMPIExtentAndDatatypeWorker(
			NodeFactory nf, ExpressionNode datatype, ClauseTransformGuide out) {
		Source source = datatype.getSource();
		// null when the datatype is neither an identifier expression nor an
		// enumeration constant
		String datatypeIdentifier = getDatatypeIdentifier(datatype);
		// Never query the cache with a null key: the cache is supplied by the
		// caller and not every Map implementation accepts null keys. An
		// unnamed datatype is never cached anyway (see below).
		String intermediateName = datatypeIdentifier != null
				? out.mpiDatatypeToIntermediateName.get(datatypeIdentifier)
				: null;

		if (intermediateName == null) {
			VariableDeclarationNode intermediateVarDecl;
			TypeNode type = size_t(nf, source);

			intermediateName = MPIContractUtilities
					.nextExtentName(out.nameCounter++);
			intermediateVarDecl = nf.newVariableDeclarationNode(source,
					nf.newIdentifierNode(source, intermediateName), type,
					createMPISizeofDatatypeCall(nf, datatype));
			out.prefix.add(intermediateVarDecl);
			// only named datatypes can be found again later on
			if (datatypeIdentifier != null)
				out.mpiDatatypeToIntermediateName.put(datatypeIdentifier,
						intermediateName);
		}
		return identifierExpression(nf, intermediateName, source);
	}

	/**
	 * Registers, for every ACSL <code>\valid</code> expression of the clause,
	 * a substitution by the expression delivered by
	 * {@link #createAllocation(ASTFactory, ContractClause, ExpressionNode, TypeNode, ExpressionNode, Source, ClauseTransformGuide)},
	 * which also adds the allocation items to the prefix of <code>out</code>.
	 * 
	 * @param clause
	 *            the clause whose <code>\valid</code> expressions are processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isAssume
	 *            must be true (checked by an assertion)
	 * @param out
	 *            the guide that receives prefix items, side conditions and
	 *            substitutions
	 * @throws SyntaxException
	 *             propagated from the helpers decomposing the expression
	 * @throws CIVLSyntaxException
	 *             if the argument of a <code>\valid</code> expression is not
	 *             in the supported form
	 */
	private static void transformAcslValid(ContractClause clause,
			ASTFactory astFactory, boolean isAssume, ClauseTransformGuide out)
			throws SyntaxException {
		assert isAssume;

		NodeFactory nf = astFactory.getNodeFactory();

		for (ExpressionNode expr : clause.specialReferences.acslValidExpressions) {
			OperatorNode valid = (OperatorNode) expr;
			Triple<ExpressionNode, ExpressionNode, ExpressionNode> pointerOffsetExtent = processAcslValidWorker(
					nf, valid);
			ExpressionNode pointer = pointerOffsetExtent.first;
			ExpressionNode offset = pointerOffsetExtent.second;
			ExpressionNode extent = pointerOffsetExtent.third;
			// the referred type is read off the pointer expression as it
			// occurs in the clause, i.e. before any offset is added to it
			TypeNode elementType = getPointerReferredTypeNode(nf, pointer);

			if (offset != null) {
				// A constant expression is not necessarily a ConstantNode,
				// hence the cast must be guarded by an instanceof test;
				// offsets not known to be zero are always added.
				boolean knownZero = offset.isConstantExpression()
						&& offset instanceof ConstantNode
						&& ((ConstantNode) offset).getConstantValue()
								.isZero() == Answer.YES;

				// The offset is copied because the node handed back by the
				// worker may already be part of the extent expression
				// (presumably a node cannot have two parents -- confirm
				// against the ABC ASTNode contract).
				if (!knownZero)
					pointer = nf.newOperatorNode(pointer.getSource(),
							Operator.PLUS, pointer.copy(), offset.copy());
			}

			ExpressionNode subst = createAllocation(astFactory, clause, pointer,
					elementType, extent, expr.getSource(), out);

			out.substitutions.put(expr,
					new CommonASTNodeSubstituteGuide(subst, expr));
		}
	}

	/**
	 * Registers, for every MPI valid expression of the clause, a substitution
	 * by the expression delivered by
	 * {@link #createAllocation(ASTFactory, ContractClause, ExpressionNode, TypeNode, ExpressionNode, Source, ClauseTransformGuide)}.
	 * The allocated object has <code>sizeof-datatype * count</code> elements,
	 * where the size of the datatype is taken from a cached intermediate
	 * variable if there is one.
	 * 
	 * @param clause
	 *            the clause whose MPI valid expressions are processed
	 * @param astFactory
	 *            the factory providing the node factory
	 * @param isAssume
	 *            must be true (checked by an assertion)
	 * @param out
	 *            the guide that receives prefix items, side conditions and
	 *            substitutions
	 * @throws SyntaxException
	 *             propagated from
	 *             {@link #createAllocation(ASTFactory, ContractClause, ExpressionNode, TypeNode, ExpressionNode, Source, ClauseTransformGuide)}
	 */
	private static void transformMPIValid(ContractClause clause,
			ASTFactory astFactory, boolean isAssume, ClauseTransformGuide out)
			throws SyntaxException {
		assert isAssume;

		NodeFactory nf = astFactory.getNodeFactory();

		for (ExpressionNode expr : clause.specialReferences.mpiValidExpressions) {
			MPIContractExpressionNode mpiValid = (MPIContractExpressionNode) expr;
			ExpressionNode buf = mpiValid.getArgument(0);
			ExpressionNode count = mpiValid.getArgument(1);
			ExpressionNode datatype = mpiValid.getArgument(2);
			// null when the datatype is neither an identifier expression nor
			// an enumeration constant
			String datatypeIdent = getDatatypeIdentifier(datatype);
			// never query the cache with a null key: not every Map
			// implementation accepts it
			String intermediateName = datatypeIdent != null
					? out.mpiDatatypeToIntermediateName.get(datatypeIdent)
					: null;
			ExpressionNode extent;
			TypeNode elementType;

			if (intermediateName == null)
				extent = createMPISizeofDatatypeCall(nf, datatype);
			else
				extent = identifierExpression(nf, intermediateName,
						datatype.getSource());
			extent = nf.newOperatorNode(count.getSource(), Operator.TIMES,
					extent, count.copy());
			/*
			 * TODO: This is a somehow fishy compromising for the two different
			 * cases of allocation. If the datatype argument is an MPI_Datatype
			 * enumeration constant, we would use its associated concrete type
			 * to allocate memory objects. Otherwise, we would use the general
			 * solution (char).
			 */
			// an unnamed datatype (null identifier) takes the general solution
			// instead of failing with a NullPointerException
			if (datatypeIdent != null && datatypeIdent.startsWith("MPI_")) {
				String typedefName = "_" + datatypeIdent + "_t";

				elementType = nf.newTypedefNameNode(
						nf.newIdentifierNode(datatype.getSource(), typedefName),
						null);
			} else
				elementType = nf.newBasicTypeNode(datatype.getSource(),
						BasicTypeKind.CHAR);

			ExpressionNode subst = createAllocation(astFactory, clause, buf,
					elementType, extent, expr.getSource(), out);

			out.substitutions.put(expr,
					new CommonASTNodeSubstituteGuide(subst, expr));
		}
	}

	/* ********************** Public Utils *********************** */
	/**
	 * Decomposes an ACSL <code>\valid</code> expression into the pointer to
	 * the first element and the number of elements.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param valid
	 *            the <code>\valid</code> expression
	 * @return a pair whose left component is the pointer expression (with the
	 *         offset added, unless the offset is absent or a zero constant)
	 *         and whose right component is the extent expression
	 * @throws SyntaxException
	 *             propagated from the helper decomposing the expression
	 * @throws CIVLSyntaxException
	 *             if the argument of the <code>\valid</code> expression is not
	 *             in the supported form
	 */
	public static Pair<ExpressionNode, ExpressionNode> processAcslValid(
			NodeFactory nf, OperatorNode valid) throws SyntaxException {
		Triple<ExpressionNode, ExpressionNode, ExpressionNode> pointerOffsetExtent = processAcslValidWorker(
				nf, valid);
		ExpressionNode pointer = pointerOffsetExtent.first;
		ExpressionNode offset = pointerOffsetExtent.second;

		if (offset != null) {
			// A constant expression is not necessarily a ConstantNode, hence
			// the cast must be guarded by an instanceof test; offsets not
			// known to be zero are always added.
			boolean knownZero = offset.isConstantExpression()
					&& offset instanceof ConstantNode
					&& ((ConstantNode) offset).getConstantValue()
							.isZero() == Answer.YES;

			// The offset is copied because the node handed back by the worker
			// may already be part of the extent expression (presumably a node
			// cannot have two parents -- confirm against the ABC ASTNode
			// contract).
			if (!knownZero)
				pointer = nf.newOperatorNode(pointer.getSource(), Operator.PLUS,
						pointer.copy(), offset.copy());
		}
		return new Pair<>(pointer, pointerOffsetExtent.third);
	}

	/**
	 * Decomposes an MPI valid expression into its buffer and the number of
	 * elements the buffer refers to.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param mpiValid
	 *            the MPI valid expression; its arguments 0, 1 and 2 are read
	 *            as buffer, count and datatype respectively
	 * @return a pair of the buffer argument (not copied) and a new expression
	 *         <code>sizeof-datatype-call * count</code>
	 * @throws SyntaxException
	 *             declared for callers; not thrown by the visible code
	 */
	public static Pair<ExpressionNode, ExpressionNode> processMPIValid(
			NodeFactory nf, MPIContractExpressionNode mpiValid)
			throws SyntaxException {
		ExpressionNode buffer = mpiValid.getArgument(0);
		ExpressionNode count = mpiValid.getArgument(1);
		ExpressionNode datatype = mpiValid.getArgument(2);
		ExpressionNode datatypeSize = createMPISizeofDatatypeCall(nf, datatype);
		// the count is copied since the original stays inside mpiValid
		ExpressionNode totalExtent = nf.newOperatorNode(count.getSource(),
				Operator.TIMES, datatypeSize, count.copy());

		return new Pair<>(buffer, totalExtent);
	}

	/* ********************** Private Utils *********************** */
	/**
	 * Creates an identifier expression for the given name.
	 * 
	 * @param nf
	 *            the node factory used to create the nodes
	 * @param name
	 *            the name of the identifier
	 * @param source
	 *            the source attached to both the identifier node and the
	 *            identifier expression node
	 * @return the new identifier expression
	 */
	private static ExpressionNode identifierExpression(NodeFactory nf,
			String name, Source source) {
		IdentifierNode identifier = nf.newIdentifierNode(source, name);

		return nf.newIdentifierExpressionNode(source, identifier);
	}

	/**
	 * Extracts the name of an MPI datatype expression.
	 * 
	 * @param mpiDatatype
	 *            the expression denoting an MPI datatype
	 * @return the name of the identifier if the expression is an identifier
	 *         expression, the name of the constant if it is an enumeration
	 *         constant, otherwise null
	 */
	private static String getDatatypeIdentifier(ExpressionNode mpiDatatype) {
		if (mpiDatatype
				.expressionKind() == ExpressionKind.IDENTIFIER_EXPRESSION) {
			IdentifierExpressionNode identifierExpr = (IdentifierExpressionNode) mpiDatatype;

			return identifierExpr.getIdentifier().name();
		}
		if (mpiDatatype instanceof EnumerationConstantNode) {
			EnumerationConstantNode enumConstant = (EnumerationConstantNode) mpiDatatype;

			return enumConstant.getName().name();
		}
		// any other form of expression has no simple name
		return null;
	}

	/**
	 * Creates a type node for the typedef name <code>size_t</code>.
	 * 
	 * @param nf
	 *            the node factory used to create the nodes
	 * @param source
	 *            the source attached to the identifier of the typedef name
	 * @return the new typedef name node
	 */
	private static TypeNode size_t(NodeFactory nf, Source source) {
		IdentifierNode typedefName = nf.newIdentifierNode(source, "size_t");

		return nf.newTypedefNameNode(typedefName, null);
	}

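	/**
	 * Creates a call to the function named by
	 * {@link MPIContractUtilities#MPI_SIZEOF_DATATYPE}, i.e. a
	 * <code>sizeofDatatype(MPI_Datatype datatype)</code> call, with a copy of
	 * the given datatype expression as its only argument.
	 * 
	 * @param nf
	 *            the node factory used to create the new nodes
	 * @param datatype
	 *            the expression denoting the MPI datatype
	 * @return the new function call expression
	 */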
	private static ExpressionNode createMPISizeofDatatypeCall(NodeFactory nf,
			ExpressionNode datatype) {
		Source source = datatype.getSource();
		ExpressionNode function = identifierExpression(nf,
				MPIContractUtilities.MPI_SIZEOF_DATATYPE, source);
		// the datatype is copied; the given node itself is left where it is
		List<ExpressionNode> arguments = Arrays.asList(datatype.copy());

		return nf.newFunctionCallNode(source, function, arguments, null);
	}

	/**
	 * Decomposes a regular range <code>low .. high</code> into its lower
	 * bound, its upper bound and the number of elements it denotes.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param range
	 *            the range to decompose; it is not modified
	 * @return a triple of a copy of the lower bound, a copy of the upper bound
	 *         and a new expression <code>high - low + 1</code> (simplified to
	 *         <code>high + 1</code> if the lower bound is known to be zero).
	 *         The three expressions share no nodes.
	 * @throws SyntaxException
	 *             if evaluating the lower bound as a constant fails
	 */
	private static Triple<ExpressionNode, ExpressionNode, ExpressionNode> decomposeRange(
			NodeFactory nf, RegularRangeNode range) throws SyntaxException {
		ExpressionNode low = range.getLow().copy();
		ExpressionNode high = range.getHigh().copy();
		Value lowValue = nf.getConstantValue(low);
		// a lower bound without a constant value (presumably reported as
		// null -- confirm against NodeFactory.getConstantValue) is treated
		// like any other bound that is not known to be zero
		boolean lowIsZero = lowValue != null
				&& lowValue.isZero() == Answer.YES;
		ExpressionNode count;

		// The count is assembled from fresh copies so that the returned bounds
		// stay free of a parent: callers attach the lower bound to another
		// expression (assumes a node may have only one parent -- confirm
		// against the ABC ASTNode contract).
		if (lowIsZero)
			count = high.copy();
		else
			count = nf.newOperatorNode(range.getSource(), Operator.MINUS,
					high.copy(), low.copy());
		count = nf.newOperatorNode(low.getSource(), Operator.PLUS, count,
				nf.newIntegerConstantNode(range.getSource(), "1"));
		return new Triple<>(low, high, count);
	}

	/**
	 * Creates a type node for the type referred to by the given pointer
	 * expression.
	 * 
	 * @param nf
	 *            the node factory handed to
	 *            {@link BaseWorker#typeNode(Source, Type, NodeFactory)}
	 * @param pointer
	 *            an expression whose type is a {@link PointerType}; any other
	 *            type makes the cast fail
	 * @return a type node for the referenced type, created with the source of
	 *         the pointer expression
	 * @throws SyntaxException
	 *             propagated from the type node creation
	 */
	private static TypeNode getPointerReferredTypeNode(NodeFactory nf,
			ExpressionNode pointer) throws SyntaxException {
		PointerType pointerType = (PointerType) pointer.getType();
		Type referencedType = pointerType.referencedType();

		return BaseWorker.typeNode(pointer.getSource(), referencedType, nf);
	}

	/**
	 * Emits the items that make a pointer refer to a freshly declared array
	 * object. In this order, the following is appended to the prefix of
	 * <code>out</code>: an assumption that the array extents are positive, the
	 * declaration of the array variable, and the assignment of the array to
	 * the (de-casted) pointer, guarded by the conjunction of the conditions of
	 * <code>out</code> if there are any. The positivity predicate is also
	 * added to the side conditions.
	 * 
	 * @param af
	 *            the factory providing node and token factories
	 * @param clause
	 *            not read by this method
	 * @param pointer
	 *            the pointer expression that shall refer to the new object
	 * @param elementType
	 *            the element type of the new array; it is copied
	 * @param numElements
	 *            the extent of the new array; it is copied
	 * @param source
	 *            the source attached to the created nodes
	 * @param out
	 *            the guide providing the name counter and the conditions, and
	 *            receiving prefix items and side conditions
	 * @return a new boolean constant <code>true</code>
	 * @throws SyntaxException
	 *             propagated from building the extent predicate
	 */
	private static ExpressionNode createAllocation(ASTFactory af,
			ContractClause clause, ExpressionNode pointer, TypeNode elementType,
			ExpressionNode numElements, Source source, ClauseTransformGuide out)
			throws SyntaxException {
		NodeFactory factory = af.getNodeFactory();
		TypeNode objectType = factory.newArrayTypeNode(source,
				elementType.copy(), numElements.copy());
		String objectName = MPIContractUtilities
				.nextAllocationName(out.nameCounter++);
		ExpressionNode target = ContractTransformerWorker.decast(pointer);
		IdentifierNode objectIdentifier = factory
				.newIdentifierNode(target.getSource(), objectName);
		VariableDeclarationNode objectDecl = factory
				.newVariableDeclarationNode(source, objectIdentifier,
						objectType);
		// assign allocated object to pointer;
		ExpressionNode assignment = factory.newOperatorNode(source,
				Operator.ASSIGN,
				Arrays.asList(target.copy(),
						factory.newIdentifierExpressionNode(source,
								objectIdentifier.copy())));
		ExpressionNode positiveExtents = arrayExtentsGTZero(factory,
				objectDecl.getTypeNode(), source);

		// For allocation, array objects need assumptions for valid extents;
		// variables as memory objects must be inserted in some place where
		// is visible to all contracts...
		// TODO: use assume push might be better here:
		out.prefix.add(createAssumption(factory, positiveExtents));
		out.prefix.add(objectDecl);
		out.sideConditions.add(positiveExtents);

		ExpressionNode guard = conjunct(af, out.conditions);

		// without conditions the assignment is unconditional
		if (guard == null)
			out.prefix.add(factory.newExpressionStatementNode(assignment));
		else
			out.prefix.add(factory.newIfNode(source, guard,
					factory.newExpressionStatementNode(assignment)));
		return factory.newBooleanConstantNode(source, true);
	}

	/**
	 * Builds the predicate stating that every extent of a (possibly nested)
	 * array type is greater than zero. The conjuncts are ordered from the
	 * innermost array type to the outermost one.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param type
	 *            the type node to inspect; extents are copied
	 * @param source
	 *            the source attached to the created nodes
	 * @return the conjunction of <code>extent &gt; 0</code> over all array
	 *         levels, or null if <code>type</code> is not an array type
	 * @throws SyntaxException
	 *             declared for callers; not thrown by the visible code
	 */
	private static ExpressionNode arrayExtentsGTZero(NodeFactory nf,
			TypeNode type, Source source) throws SyntaxException {
		// walk from the outermost array type inwards, remembering the levels
		// with the innermost one in front
		LinkedList<ArrayTypeNode> innermostFirst = new LinkedList<>();

		for (TypeNode current = type; current
				.kind() == TypeNodeKind.ARRAY; current = ((ArrayTypeNode) current)
						.getElementType())
			innermostFirst.addFirst((ArrayTypeNode) current);

		ExpressionNode allPositive = null;

		for (ArrayTypeNode level : innermostFirst) {
			ExpressionNode positive = nf.newOperatorNode(source, Operator.GT,
					level.getExtent().copy(),
					nf.newIntegerConstantNode(source, "0"));

			if (allPositive != null)
				allPositive = nf.newOperatorNode(source, Operator.LAND,
						allPositive, positive);
			else
				allPositive = positive;
		}
		return allPositive;
	}

	/**
	 * Creates a statement calling the function named by
	 * {@link BaseWorker#ASSUME} on a copy of the given predicate.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param pred
	 *            the predicate to assume; the given node itself is not
	 *            attached to the result
	 * @return the expression statement wrapping the call
	 */
	private static StatementNode createAssumption(NodeFactory nf,
			ExpressionNode pred) {
		Source source = pred.getSource();
		ExpressionNode function = identifierExpression(nf, BaseWorker.ASSUME,
				source);
		FunctionCallNode call = nf.newFunctionCallNode(source, function,
				Arrays.asList(pred.copy()), null);

		return nf.newExpressionStatementNode(call);
	}

	/**
	 * Decomposes the argument of an ACSL <code>\valid</code> expression, which
	 * must have the form <code>ptr</code> or <code>ptr + range</code>.
	 * 
	 * @param nf
	 *            the node factory used to create new nodes
	 * @param valid
	 *            the <code>\valid</code> expression
	 * @return a triple (pointer, offset, extent): the pointer is the original
	 *         node taken from <code>valid</code>; the offset is the lower
	 *         bound if the second operand is a regular range and null
	 *         otherwise; the extent is the element count of the range, or the
	 *         constant 1 if there is no range
	 * @throws CIVLSyntaxException
	 *             if the argument is an operator expression other than an
	 *             addition
	 * @throws SyntaxException
	 *             propagated from decomposing the range
	 */
	private static Triple<ExpressionNode, ExpressionNode, ExpressionNode> processAcslValidWorker(
			NodeFactory nf, OperatorNode valid) throws SyntaxException {
		ExpressionNode arg = valid.getArgument(0);
		ExpressionNode offset = null;
		ExpressionNode pointer;
		ExpressionNode extent;

		if (arg.expressionKind() != ExpressionKind.OPERATOR) {
			// a plain pointer refers to exactly one element
			pointer = arg;
			extent = nf.newIntegerConstantNode(valid.getSource(), "1");
		} else {
			OperatorNode sum = (OperatorNode) arg;

			// Check if the argument of valid is in a limited form:
			if (sum.getOperator() != Operator.PLUS)
				throw new CIVLSyntaxException(
						"CIVL requires the argument of \\valid "
								+ "expression to be a limited form:\n"
								+ "ptr (+ range)?\n"
								+ "range can be either an integer-expression\n "
								+ "or has the form \"integer-expression .. integer-expression\"",
						sum.getSource());
			pointer = sum.getArgument(0);

			ExpressionNode addend = sum.getArgument(1).copy();

			if (addend.expressionKind() != ExpressionKind.REGULAR_RANGE) {
				// NOTE(review): an integer addend is not reported as offset,
				// so "ptr + i" yields the same triple as "ptr" -- confirm that
				// this is intended.
				extent = nf.newIntegerConstantNode(addend.getSource(), "1");
			} else {
				Triple<ExpressionNode, ExpressionNode, ExpressionNode> lowHighCount = decomposeRange(
						nf, (RegularRangeNode) addend);

				offset = lowHighCount.first; // low
				extent = lowHighCount.third; // high - low + 1
			}
		}
		return new Triple<>(pointer, offset, extent);
	}

	/**
	 * Builds the logical conjunction of copies of the given expressions.
	 * 
	 * @param af
	 *            the factory providing node and token factories
	 * @param exprs
	 *            the expressions to conjoin; none of them is modified
	 * @return null for an empty list, a copy of the element for a singleton
	 *         list, and otherwise a nest of <code>&amp;&amp;</code> nodes in
	 *         which each later element is the LEFT operand of the conjunction
	 *         built so far
	 */
	private static ExpressionNode conjunct(ASTFactory af,
			List<ExpressionNode> exprs) {
		TokenFactory tokens = af.getTokenFactory();
		NodeFactory factory = af.getNodeFactory();
		ExpressionNode conjunction = null;
		Source joined = null;

		for (ExpressionNode operand : exprs) {
			// the source of the conjunction spans all operands seen so far
			if (joined == null)
				joined = operand.getSource();
			else
				joined = tokens.join(joined, operand.getSource());
			// NOTE(review): operands end up in reverse list order -- confirm
			// that no operand relies on an earlier one being evaluated first.
			if (conjunction == null)
				conjunction = operand.copy();
			else
				conjunction = factory.newOperatorNode(joined, Operator.LAND,
						operand.copy(), conjunction);
		}
		return conjunction;
	}
}