LogicFunctionTransformer.java

package dev.civl.mc.transform.common;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;

import dev.civl.abc.ast.entity.IF.Entity;
import dev.civl.abc.ast.entity.IF.Entity.EntityKind;
import dev.civl.abc.ast.entity.IF.Function;
import dev.civl.abc.ast.node.IF.ASTNode;
import dev.civl.abc.ast.node.IF.ASTNode.NodeKind;
import dev.civl.abc.ast.node.IF.IdentifierNode;
import dev.civl.abc.ast.node.IF.NodeFactory;
import dev.civl.abc.ast.node.IF.PairNode;
import dev.civl.abc.ast.node.IF.SequenceNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDefinitionNode;
import dev.civl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode.ExpressionKind;
import dev.civl.abc.ast.node.IF.expression.FunctionCallNode;
import dev.civl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode.Operator;
import dev.civl.abc.ast.node.IF.expression.QuantifiedExpressionNode;
import dev.civl.abc.ast.node.IF.type.ArrayTypeNode;
import dev.civl.abc.ast.node.IF.type.FunctionTypeNode;
import dev.civl.abc.ast.node.IF.type.PointerTypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode.TypeNodeKind;
import dev.civl.abc.ast.type.IF.FunctionType;
import dev.civl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import dev.civl.abc.ast.type.IF.Type.TypeKind;
import dev.civl.abc.token.IF.Source;
import dev.civl.abc.token.IF.SyntaxException;
import dev.civl.abc.token.IF.TokenFactory;
import dev.civl.abc.util.IF.Pair;
import dev.civl.mc.model.IF.CIVLUnimplementedFeatureException;

/**
 * <p>
 * A logic function transformer helps transform 1) a logic function definition
 * into a form that can be easily evaluated in a stateless way, and 2) a logic
 * function call expression into a form that corresponds to the change of its
 * definition.
 * </p>
 * 
 * <p>
 * A logic function definition is supposed to be stateless. Objects referred
 * to by the definition must be declared as formal parameters, except for
 * pointers: dereferencing a pointer needs a state. However, instead of passing
 * pointers to logic functions, passing the array values referred to by the
 * pointers makes the logic function definition stateless. This idea brings
 * "valid negative array element indices" into the definition, since a pointer
 * can point to the middle of an array.
 * </p>
 * 
 * <p>
 * To get rid of "valid negative array element indices", this transformer
 * transforms a pointer <code>p</code> to a base address <code>q</code> and an
 * offset <code>oft</code> s.t. <code>q + oft = p</code> and <code>q - 1</code>
 * is invalid. Then any appearance of <code>p</code> in the definition is
 * replaced with <code>&q[oft]</code>. In addition, a logic function call
 * expression <code>f(..., e(p), ...)</code> in a definition is transformed to
 * <code>f(..., e(&q[oft]), 0, ...)</code>, where <code>e(p)</code> is an
 * expression of pointer type involving <code>p</code>. In the generated code
 * the base address keeps the name of <code>p</code> and the offset is named
 * <code>_oft_p</code>.
 * </p>
 * 
 * <p>
 * For quantified expressions <code>Q</code> with a pointer type bound variable
 * <code>p</code>: <code>
 * Q := FORALL p. P(p) will be transformed to FORALL oft. FORALL q. P(&q[oft]).
 * Q := EXISTS p. P(p) will be transformed to EXISTS oft. EXISTS q. P(&q[oft]).
 * </code>
 * </p>
 * 
 * 
 * @author ziqing
 *
 */
public class LogicFunctionTransformer {

	/**
	 * Name prefix for the generated extra offset argument (also used for the
	 * offset bound variables generated for quantified expressions).
	 */
	private static final String OFFSET_NAME_PREFIX = "_oft_";

	/**
	 * Name of the system function which maps a pointer p to another pointer q
	 * such that there exists an integer offset with q + offset = p, and q - 1
	 * is invalid.
	 */
	private static final String ARRAY_BASE_ADDRESS_OF = "$array_base_address_of";

	/**
	 * A reference to {@link NodeFactory}
	 */
	private final NodeFactory nodeFactory;

	/**
	 * A reference to {@link TokenFactory}
	 */
	private final TokenFactory tokenFactory;

	/**
	 * A pointer-typed formal parameter or bound variable, split into its base
	 * address (which keeps the original identifier) and a freshly created
	 * offset identifier named <code>OFFSET_NAME_PREFIX + name</code>.
	 */
	private class Pointer {
		final IdentifierNode baseAddr;
		final IdentifierNode offset;

		Pointer(IdentifierNode baseAddr) {
			this.baseAddr = baseAddr;
			this.offset = nodeFactory.newIdentifierNode(baseAddr.getSource(),
					OFFSET_NAME_PREFIX + baseAddr.name());
			assert this.baseAddr != null && this.offset != null;
		}
	}

	public LogicFunctionTransformer(NodeFactory nodeFactory,
			TokenFactory tokenFactory) {
		this.nodeFactory = nodeFactory;
		this.tokenFactory = tokenFactory;
	}

	/**
	 * Transforms, in place, a logic function declaration (and its logic
	 * definition, if it is a definition) into a form that is easy for the
	 * back-end to evaluate in a stateless way; see
	 * {@link LogicFunctionTransformer}. Every pointer-typed formal parameter
	 * <code>p</code> is followed by a new <code>int</code> parameter
	 * <code>_oft_p</code>, and every use of <code>p</code> in the definition
	 * becomes <code>&p[_oft_p]</code>. Declarations that are not logic
	 * functions are left untouched.
	 * 
	 * @param logicFunctionDecl
	 *            a function declaration; only logic functions are transformed
	 * @throws SyntaxException
	 *             if the node factory fails to create a new node
	 * @throws CIVLUnimplementedFeatureException
	 *             if a pointer-typed formal parameter has an unsupported type
	 *             (see {@link #supportedFormalType(ASTNode, TypeNode)})
	 */
	public void transformDefinition(FunctionDeclarationNode logicFunctionDecl)
			throws SyntaxException {
		if (!logicFunctionDecl.isLogicFunction())
			return;

		FunctionTypeNode typeNode = (FunctionTypeNode) logicFunctionDecl
				.getTypeNode();
		List<VariableDeclarationNode> newParams = new LinkedList<>();
		List<Pointer> pointerParams = new LinkedList<>();
		Source newParamSource = null;

		for (VariableDeclarationNode formal : typeNode.getParameters()) {
			if (formal.getTypeNode().getType().kind() == TypeKind.POINTER) {
				Pointer pointerParam = new Pointer(formal.getIdentifier());

				supportedFormalType(formal, formal.getTypeNode());
				newParams.add(formal.copy());
				// the offset parameter immediately follows its pointer:
				newParams.add(nodeFactory.newVariableDeclarationNode(
						formal.getSource(), pointerParam.offset.copy(),
						nodeFactory.newBasicTypeNode(formal.getSource(),
								BasicTypeKind.INT)));
				pointerParams.add(pointerParam);
			} else
				newParams.add(formal.copy());
			newParamSource = newParamSource == null
					? formal.getSource()
					: tokenFactory.join(newParamSource, formal.getSource());
		}
		if (newParamSource == null)
			newParamSource = logicFunctionDecl.getSource();

		Stack<Pointer[]> pointersStack = new Stack<>();

		pointersStack.push(pointerParams.toArray(new Pointer[0]));
		if (logicFunctionDecl.isDefinition()) {
			// the definition is transformed in place in the AST:
			transformExpression(((FunctionDefinitionNode) logicFunctionDecl)
					.getLogicDefinition(), pointersStack);
		}
		typeNode.getParameters().remove();
		typeNode.setParameters(nodeFactory.newSequenceNode(newParamSource,
				"logic function params", newParams));
	}

	/**
	 * Transforms, in place, the sub-tree rooted at the given expression:
	 * identifier expressions referring to pointers in
	 * <code>pointersStack</code> are replaced by <code>&q[oft]</code>,
	 * pointer-typed arguments of function calls get an extra <code>0</code>
	 * offset argument, and quantified expressions with pointer-typed bound
	 * variables are rewritten.
	 * 
	 * @param expression
	 *            the root of the sub-tree to transform
	 * @param pointersStack
	 *            the pointers in scope, innermost scope on top
	 * @return the root of the transformed sub-tree; this is a new node if the
	 *         given expression itself had to be replaced
	 * @throws SyntaxException
	 *             if the node factory fails to create a new node
	 */
	private ExpressionNode transformExpression(ExpressionNode expression,
			Stack<Pointer[]> pointersStack) throws SyntaxException {
		ASTNode parent = expression.parent();
		int childIdx = expression.childIndex();
		List<Pair<ASTNode, ExpressionNode>> replacements = new LinkedList<>();
		ASTNode node = expression;

		// Detach the sub-tree so that nextDFS() never leaves it:
		if (parent != null)
			expression.remove();
		do {
			if (node.nodeKind() == NodeKind.EXPRESSION) {
				ExpressionNode expr = (ExpressionNode) node;
				ExpressionNode replacement = expr;

				switch (expr.expressionKind()) {
					case IDENTIFIER_EXPRESSION :
						replacement = transformIdentifierExpressionWorker(
								(IdentifierExpressionNode) expr, pointersStack);
						break;
					case FUNCTION_CALL :
						// rewrites the argument list of the call in place;
						// the new arguments are visited later by this DFS:
						transformFuncCallExpressionWorker(
								(FunctionCallNode) expr);
						break;
					case QUANTIFIED_EXPRESSION :
						replacement = transformQuantifiedExpressionWorker(
								(QuantifiedExpressionNode) expr, pointersStack);
						break;
					default :
				}
				if (replacement != expr)
					replacements.add(new Pair<>(node, replacement));
			}
		} while ((node = node.nextDFS()) != null);
		if (parent != null)
			parent.setChild(childIdx, expression);

		ExpressionNode result = expression;

		// Do all replacements at once, so that the DFS above is not disturbed:
		for (Pair<ASTNode, ExpressionNode> replace : replacements) {
			ASTNode target = replace.left;
			ASTNode targetParent = target.parent();

			// the root itself may be replaced; the caller must see the new root
			if (target == expression)
				result = replace.right;
			if (targetParent != null) {
				int targetIdx = target.childIndex();

				target.remove();
				targetParent.setChild(targetIdx, replace.right);
			}
		}
		return result;
	}

	/**
	 * transforms <code>p</code> to <code>&q[offset]</code>, where
	 * <code>q + offset = p</code>.
	 * 
	 * @return new expression node can be used to replace the given identifier
	 *         expression, or the given identifier expression itself if it does
	 *         not refer to a pointer in <code>pointersStack</code>.
	 */
	private ExpressionNode transformIdentifierExpressionWorker(
			IdentifierExpressionNode identifierExpr,
			Stack<Pointer[]> pointersStack) {
		Pointer matched = match(identifierExpr.getIdentifier(), pointersStack);

		if (matched == null)
			return identifierExpr;

		// p -> &q[oft]: //TODO: think about what type of pointer is correct
		Source source = identifierExpr.getSource();
		ExpressionNode element = nodeFactory.newOperatorNode(source,
				Operator.SUBSCRIPT,
				nodeFactory.newIdentifierExpressionNode(source,
						matched.baseAddr.copy()),
				nodeFactory.newIdentifierExpressionNode(source,
						matched.offset.copy()));

		return nodeFactory.newOperatorNode(source, Operator.ADDRESSOF,
				element);
	}

	/**
	 * Transforms <code>f(..., e(p), ...)</code>, where <code>e(p)</code> is of
	 * pointer type, into <code>f(..., e(p), 0, ...)</code> by replacing the
	 * argument list of the call. Pointers inside the arguments are NOT
	 * transformed here. This method is only called during the DFS of
	 * {@link #transformExpression(ExpressionNode, Stack)}, which visits the
	 * new arguments afterwards and turns <code>e(p)</code> into
	 * <code>e(&q[oft])</code>. Transforming them here as well would transform
	 * them twice.
	 */
	private void transformFuncCallExpressionWorker(FunctionCallNode callNode)
			throws SyntaxException {
		if (callNode.getFunction()
				.expressionKind() != ExpressionKind.IDENTIFIER_EXPRESSION)
			return;

		List<ExpressionNode> newArgs = new LinkedList<>();
		SequenceNode<ExpressionNode> oldArgs = callNode.getArguments();

		for (ExpressionNode arg : oldArgs) {
			boolean isPointer = arg.getType().kind() == TypeKind.POINTER;

			arg.remove();
			newArgs.add(arg);
			if (isPointer)
				newArgs.add(nodeFactory.newIntegerConstantNode(arg.getSource(),
						"0"));
		}
		oldArgs.remove();
		callNode.setArguments(nodeFactory.newSequenceNode(oldArgs.getSource(),
				"logic-func-args", newArgs));
	}

	/**
	 * transforms <code>Quantifier *p | R(p). P(p)</code> to
	 * <code>Quantifier a. Quantifier *p | R(&p[a]). P(&p[a])</code>. Both the
	 * restriction and the predicate are transformed with the bound pointers in
	 * scope. NOTE(review): bound-variable domain expressions and the interval
	 * sequence are copied untransformed.
	 * 
	 * @return new expression node can be used to replace the given quantified
	 *         expression, or the given node itself if it binds no pointer.
	 */
	private ExpressionNode transformQuantifiedExpressionWorker(
			QuantifiedExpressionNode quantNode, Stack<Pointer[]> pointersStack)
			throws SyntaxException {
		List<Pointer> pointers = new LinkedList<>();

		for (PairNode<SequenceNode<VariableDeclarationNode>, ExpressionNode> bvs : quantNode
				.boundVariableList())
			for (VariableDeclarationNode bv : bvs.getLeft())
				if (bv.getTypeNode().getType().kind() == TypeKind.POINTER)
					pointers.add(new Pointer(bv.getIdentifier()));
		if (pointers.isEmpty())
			return quantNode;

		Pointer[] pointersArray = pointers.toArray(new Pointer[0]);
		ExpressionNode restriction = quantNode.restriction();
		ExpressionNode pred;

		pointersStack.push(pointersArray);
		try {
			if (restriction != null)
				restriction = transformExpression(restriction, pointersStack);
			pred = transformExpression(quantNode.expression(), pointersStack);
		} finally {
			pointersStack.pop();
		}
		// restriction and interval sequence are optional (may be null):
		pred = nodeFactory.newQuantifiedExpressionNode(quantNode.getSource(),
				quantNode.quantifier(), quantNode.boundVariableList().copy(),
				restriction == null ? null : restriction.copy(), pred.copy(),
				quantNode.intervalSequence() == null
						? null
						: quantNode.intervalSequence().copy());

		List<VariableDeclarationNode> offsetsBv = new LinkedList<>();
		SequenceNode<PairNode<SequenceNode<VariableDeclarationNode>, ExpressionNode>> offsets;
		Source offsetsSource = null;

		for (Pointer ptr : pointersArray) {
			Source ptrSource = ptr.offset.getSource();

			offsetsBv.add(nodeFactory.newVariableDeclarationNode(ptrSource,
					ptr.offset.copy(),
					nodeFactory.newBasicTypeNode(ptrSource, BasicTypeKind.INT)));
			offsetsSource = offsetsSource != null
					? tokenFactory.join(offsetsSource, ptrSource)
					: ptrSource;
		}
		assert offsetsSource != null;
		offsets = nodeFactory.newSequenceNode(offsetsSource,
				"bounded-offset-sequence",
				Arrays.asList(nodeFactory.newPairNode(offsetsSource,
						nodeFactory.newSequenceNode(offsetsSource,
								"bounded-offsets", offsetsBv),
						null)));
		return nodeFactory.newQuantifiedExpressionNode(pred.getSource(),
				quantNode.quantifier(), offsets, null, pred, null);
	}

	/**
	 * Finds the pointer (formal parameter or bound variable) that the given
	 * identifier refers to, by name. The innermost scope is searched first, so
	 * a bound variable shadows an outer pointer with the same name. Note that
	 * iterating a {@link Stack} directly would go bottom-up, i.e.
	 * outermost first.
	 * 
	 * @return the matched pointer, or null if the given node is not an
	 *         identifier of one of the pointers in scope.
	 */
	private Pointer match(ASTNode node, Stack<Pointer[]> pointersStack) {
		if (!(node instanceof IdentifierNode))
			return null;

		String name = ((IdentifierNode) node).name();

		for (int i = pointersStack.size() - 1; i >= 0; i--)
			for (Pointer ptr : pointersStack.get(i))
				if (ptr.baseAddr.name().equals(name))
					return ptr;
		return null;
	}

	/**
	 * <p>
	 * Transforms, in place, a logic function call that is NOT in any logic
	 * function definition into a form that corresponds to the change of its
	 * definition; see
	 * {@link #transformDefinition(FunctionDeclarationNode)}. Calls to
	 * non-logic functions, or through non-identifier function expressions, are
	 * left untouched.
	 * </p>
	 * <p>
	 * A logic function call with pointer-type actual parameter <code>p</code>
	 * <code>f(..., p, ...)</code> will be transformed to
	 * <code>f(..., $array_base_address_of(p),  p - (T)$array_base_address_of(p), ...)</code>,
	 * where <code>T</code> is the type of the corresponding formal parameter.
	 * </p>
	 * <p>
	 * The parameter indexing below assumes that the first declaration of the
	 * callee has already been transformed by
	 * {@link #transformDefinition(FunctionDeclarationNode)}, i.e. every
	 * pointer formal is followed by its offset formal.
	 * </p>
	 * 
	 * @param expression
	 *            a function call expression to a logic function
	 */
	public void transformCall(FunctionCallNode expression) {
		ExpressionNode function = expression.getFunction();

		if (function.expressionKind() != ExpressionKind.IDENTIFIER_EXPRESSION)
			return;

		IdentifierExpressionNode funcIdent = (IdentifierExpressionNode) function;
		Entity entity = funcIdent.getIdentifier().getEntity();
		Function funcEntity;

		if (entity == null || entity.getEntityKind() != EntityKind.FUNCTION)
			return;
		funcEntity = (Function) entity;
		if (!funcEntity.isLogic())
			return;

		SequenceNode<ExpressionNode> args = expression.getArguments();
		List<ExpressionNode> newArgs = new LinkedList<>();
		int idx = 0; // index into the (transformed) formal parameters
		FunctionDeclarationNode logicFuncDecl = (FunctionDeclarationNode) funcEntity
				.getFirstDeclaration();
		FunctionTypeNode funcType = (FunctionTypeNode) logicFuncDecl
				.getTypeNode();

		for (ExpressionNode arg : args) {
			if (arg.getType().kind() == TypeKind.POINTER) {
				TypeNode formalType = funcType.getParameters()
						.getSequenceChild(idx).getTypeNode();

				// both helpers copy the argument, it is never re-attached:
				newArgs.add(arrayBaseAddressOf(arg));
				newArgs.add(offsetToArrayBase(formalType, arg));
				idx += 2; // skip the pointer formal and its offset formal
			} else {
				arg.remove();
				newArgs.add(arg);
				idx++;
			}
		}
		args.remove();
		expression.setArguments(nodeFactory.newSequenceNode(args.getSource(),
				"logic-function arguments", newArgs));
	}

	/**
	 * transforms <code>p</code> to <code>$array_base_address_of(p)</code>,
	 * using a copy of <code>p</code>.
	 */
	private ExpressionNode arrayBaseAddressOf(ExpressionNode pointer) {
		Source source = pointer.getSource();

		return nodeFactory.newFunctionCallNode(source,
				nodeFactory.newIdentifierExpressionNode(source,
						nodeFactory.newIdentifierNode(source,
								ARRAY_BASE_ADDRESS_OF)),
				Arrays.asList(pointer.copy()), null);
	}

	/**
	 * generates <code>p - (T)$array_base_address_of(p)</code>, where
	 * <code>T</code> is given by <code>pointerTypeNode</code>, using copies of
	 * <code>p</code>.
	 */
	private ExpressionNode offsetToArrayBase(TypeNode pointerTypeNode,
			ExpressionNode pointer) {
		Source source = pointer.getSource();

		return nodeFactory.newOperatorNode(source, Operator.MINUS,
				pointer.copy(), nodeFactory.newCastNode(source,
						pointerTypeNode.copy(), arrayBaseAddressOf(pointer)));
	}

	/**
	 * Checks that the given type node of a formal parameter represents a basic
	 * type, or a pointer to a type that contains no pointer (see
	 * {@link #noPointerIn(TypeNode)}).
	 * 
	 * @throws CIVLUnimplementedFeatureException
	 *             if the type is not supported
	 */
	private void supportedFormalType(ASTNode formal, TypeNode typeNode) {
		boolean supported;

		if (typeNode.kind() == TypeNodeKind.BASIC)
			supported = true;
		else if (typeNode.kind() == TypeNodeKind.POINTER)
			supported = noPointerIn(
					((PointerTypeNode) typeNode).referencedType());
		else
			supported = false;
		if (!supported)
			throw new CIVLUnimplementedFeatureException(
					"A formal parameter of logic function has non-scalar type,"
							+ " pointer to pointer type or pointer to non-scalar type.",
					formal.getSource());
	}

	/**
	 * @return true iff the given type node contains no sub-type which is a
	 *         pointer type. e.g. if the given type node represents an array of
	 *         int type, there is NO sub-type in it is a pointer type; if the
	 *         given type node represents an array of pointer to int, then there
	 *         IS sub-type in it is a pointer type. Type node kinds other than
	 *         basic, pointer and array are conservatively reported as false.
	 */
	private boolean noPointerIn(TypeNode typeNode) {
		if (typeNode.kind() == TypeNodeKind.POINTER)
			return false;
		if (typeNode.kind() == TypeNodeKind.BASIC)
			return true;
		if (typeNode.kind() == TypeNodeKind.ARRAY)
			return noPointerIn(((ArrayTypeNode) typeNode).getElementType());
		return false;
	}
}