// SideEffectRemover.java

package edu.udel.cis.vsl.tass.ast.impl;

import edu.udel.cis.vsl.tass.ast.IF.ASTFactoryIF;
import edu.udel.cis.vsl.tass.ast.IF.ASTNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.AbstractSyntaxTreeIF;
import edu.udel.cis.vsl.tass.ast.IF.GlobalScopeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.IdentifierNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.SequenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.SideEffectRemoverIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.FunctionDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.LocalVariableDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.VariableDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.AssignmentNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.ExpressionNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.FunctionInvocationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.FunctionReferenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.IncrementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.OperatorNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.OperatorNodeIF.AST_OPERATOR;
import edu.udel.cis.vsl.tass.ast.IF.expression.SideEffectExpressionNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.VariableReferenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.AssertStatementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.BlockNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.ForLoopNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.LoopNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.StatementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.BooleanTypeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.IntegerTypeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.TypeNodeIF;
import edu.udel.cis.vsl.tass.model.IF.SyntaxException;

public class SideEffectRemover implements SideEffectRemoverIF {

	int tempIndex;
	String tempVarPrefix = "TASSTempVariable";
	ASTFactoryIF factory;

	public SideEffectRemover() {
		tempIndex = 0;
	}

	@Override
	public void setFactory(ASTFactoryIF factory) {
		this.factory = factory;
	}

	@Override
	public void transform(AbstractSyntaxTreeIF ast) throws SyntaxException {
		factory = ast.factory();
		for (int i = 0; i < ast.rootNode().globalScopeNodes().numChildren(); i++) {
			GlobalScopeNodeIF child = (GlobalScopeNodeIF) ast.rootNode()
					.globalScopeNodes().child(i);
			// Only need to worry about side effects in function bodies.
			if (child instanceof FunctionDeclarationNodeIF) {
				processBody(((FunctionDeclarationNodeIF) child).body());
				// ((FunctionDeclarationNodeIF)
				// ast.rootNode().globalScopeNodes().child(i)).setBody(((FunctionDeclarationNodeIF)
				// child).body());
			}
		}
	}

	/** Process the body of a loop or function. */
	private void processBody(BlockNodeIF body) {
		SequenceNodeIF<StatementNodeIF> statements;

		// System and abstract functions don't have bodies
		if (body == null) {
			return;
		}
		statements = body.statements();
		for (int i = 0; i < statements.numChildren(); i++) {
			StatementNodeIF statement = statements.getSequenceChild(i);
			if (hasSideEffects(statement)) {
				statement = removeSideEffects(statement);
				body.statements().setChild(i, statement);
			}
		}
	}

	@Override
	public String name() {
		return "sideEffectRemover";
	}

	@Override
	public StatementNodeIF removeSideEffects(StatementNodeIF statement) {
		if (statement instanceof AssertStatementNodeIF) {
			return processAssert((AssertStatementNodeIF) statement);
		} else if (statement instanceof OperatorNodeIF) {
			return processOperator((OperatorNodeIF) statement);
		} else if (statement instanceof AssignmentNodeIF) {
			return processAssign((AssignmentNodeIF) statement);
		} else if (statement instanceof FunctionInvocationNodeIF) {
			// A standalone function invocation is fine, even though it's a
			// SideEffectExpressionIF.
			return statement;
		} else if (statement instanceof IncrementNodeIF) {
			return processIncrement((IncrementNodeIF) statement);
		} else if (statement instanceof ForLoopNodeIF) {
			return processForLoop((ForLoopNodeIF) statement);
		} else if (statement instanceof LoopNodeIF) {
			return processLoop((LoopNodeIF) statement);
		}
		return null;
	}

	/**
	 * Loop nodes can contain side effects in their body, but each statement in
	 * the body must have side effects removed.
	 */
	private StatementNodeIF processLoop(LoopNodeIF statement) {
		if (statement.body() instanceof BlockNodeIF) {
			processBody((BlockNodeIF) statement.body());
		} else {
			statement.setBody(removeSideEffects(statement.body()));
		}
		return statement;
	}

	/**
	 * A for loop with side effects in the initializer or increment needs to
	 * have them handled. Otherwise, it is like a regular loop.
	 */
	private StatementNodeIF processForLoop(ForLoopNodeIF statement) {
		if (hasSideEffects(statement.initializer())) {
			statement
					.setInitializer(removeSideEffects(statement.initializer()));
		}
		if (hasSideEffects(statement.incrementer())) {
			statement
					.setIncrementer(removeSideEffects(statement.incrementer()));
		}
		if (hasSideEffects(statement.condition())) {
			// TODO: Figure out what to do here.
			// Removing side effects usually returns a statement.
			// But a condition can only be an expression.
			// However, maybe a condition should be an expression statement
			// once expressions are fixed to be no longer statements?
			removeSideEffects(statement.condition());
		}
		return processLoop(statement);
	}

	/**
	 * This takes a pre- or post- increment or decrement, and replaces it with a
	 * block of statements. Other parts of SideEffectRemover assume that the
	 * left hand side of the last statement in a new side-effect free block is
	 * the expression needed for evaluation in the containing expression. Thus
	 * pre- and post- increments need to be handled in different ways. For
	 * pre-increments, the block will have a statement for the increment and
	 * then the new incremented value. For post-increments, a temporary variable
	 * will be introduced to store the original value, the value will be
	 * incremented, and the temporary variable will be given again.
	 */
	private StatementNodeIF processIncrement(IncrementNodeIF increment) {
		IdentifierNodeIF id = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		IntegerTypeNodeIF type = factory.integerTypeNode();
		BlockNodeIF block = factory.blockNode(variables, statements);
		LocalVariableDeclarationNodeIF tempVariable = factory
				.localVariableDeclarationNode(id, type, block);
		VariableReferenceNodeIF tempReference = factory.variableReferenceNode(
				tempVariable, id);
		OperatorNodeIF incrementOperation;

		if (increment.increment()) {
			incrementOperation = factory.operatorNode(AST_OPERATOR.ADD,
					increment.lhs(), factory.integerLiteralNode(null, type, 1));
		} else {
			incrementOperation = factory.operatorNode(AST_OPERATOR.SUBTRACT,
					increment.lhs(), factory.integerLiteralNode(null, type, 1));
		}
		variables.addSequenceChild(tempVariable);
		if (increment.prefix()) {
			statements.addSequenceChild(factory.assignmentNode(increment.lhs(),
					incrementOperation));
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					increment.lhs()));
		} else {
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					increment.lhs()));
			statements.addSequenceChild(factory.assignmentNode(increment.lhs(),
					incrementOperation));
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					tempReference));
		}
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	private StatementNodeIF processAssert(AssertStatementNodeIF statement) {
		IdentifierNodeIF id = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		BlockNodeIF block = factory.blockNode(variables, statements);
		BooleanTypeNodeIF type = (BooleanTypeNodeIF) factory.booleanTypeNode();
		LocalVariableDeclarationNodeIF tempVariable = factory
				.localVariableDeclarationNode(id, type, block);
		VariableReferenceNodeIF tempReference = factory.variableReferenceNode(
				tempVariable, id);
		StatementNodeIF assignment = factory.assignmentNode(tempReference,
				statement.predicate());
		AssertStatementNodeIF assertStatement = factory
				.assertStatementNode(tempReference);

		variables.addSequenceChild(tempVariable);
		// Remove side effects from the assignment. This may return an
		// AssignmentNodeIF or a BlockNodeIF. We need to treat these
		// differently.
		assignment = removeSideEffects(assignment);
		if (assignment instanceof BlockNodeIF) {
			BlockNodeIF innerBlock = (BlockNodeIF) assignment;
			for (int i = 0; i < innerBlock.variables().numChildren(); i++) {
				variables.addSequenceChild(innerBlock.variables()
						.getSequenceChild(i));
			}
			statements = innerBlock.statements();
		} else {
			statements.addSequenceChild(assignment);
		}
		statements.addSequenceChild(assertStatement);
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	private StatementNodeIF processOperator(OperatorNodeIF operator) {
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		BlockNodeIF block = factory.blockNode(variables, statements);
		IdentifierNodeIF id0, id1;
		VariableDeclarationNodeIF decl0, decl1;
		VariableReferenceNodeIF ref0, ref1;
		AssignmentNodeIF assign0, assign1;
		OperatorNodeIF newOperator;
		if (operator.numChildren() == 2) {

		} else if (operator.numChildren() == 3) {
			id0 = factory.identifierNode(tempVarPrefix + tempIndex++);
			id1 = factory.identifierNode(tempVarPrefix + tempIndex++);
			// TODO: determine types by recursively examining expressions
			if (operator.getOperator() == OperatorNodeIF.AST_OPERATOR.EQUALS) {
				TypeNodeIF type = factory.integerTypeNode();
				decl0 = factory.localVariableDeclarationNode(id0, type, block);
				ref0 = factory.variableReferenceNode(decl0, decl0.identifier());
				assign0 = factory.assignmentNode(ref0, operator.getArgument(0));
				decl1 = factory.localVariableDeclarationNode(id1, type, block);
				ref1 = factory.variableReferenceNode(decl1, decl1.identifier());
				assign1 = factory.assignmentNode(ref1, operator.getArgument(1));
				variables.addSequenceChild(decl0);
				variables.addSequenceChild(decl1);
				statements.addSequenceChild(assign0);
				statements.addSequenceChild(assign1);
				newOperator = factory.operatorNode(operator.getOperator(),
						ref0, ref1);
				statements.addSequenceChild(newOperator);
			}

		} else if (operator.numChildren() == 4) {
			// Conditional expression.
			// First check if this is actually an assert
			ExpressionNodeIF argument2 = operator.getArgument(2);
			// TODO: Handle expressions that evaluate to function references but
			// aren't necessarily themselves function references
			if (argument2 instanceof FunctionInvocationNodeIF
					&& ((FunctionReferenceNodeIF) ((FunctionInvocationNodeIF) argument2)
							.function()).referent().identifier().name()
							.equals("__assert__")) {
				StatementNodeIF assertStatement = factory
						.assertStatementNode(operator.getArgument(0));
				if (hasSideEffects(assertStatement)) {
					assertStatement = removeSideEffects(assertStatement);
				}
				return assertStatement;
			} else {
				// TODO: check side effects for general (non-assert) conditional
				// statements.
			}
		}
		block.setVariables(variables);
		block.setStatements(statements);
		block.setSource(operator.getSource());
		return block;
	}

	private StatementNodeIF processAssign(AssignmentNodeIF assignment) {
		BlockNodeIF block;
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		AssignmentNodeIF newAssignment;
		ExpressionNodeIF finalExpression;

		// The right hand side can be a function invocation (even though that
		// has side effects)
		if (assignment.rhs() instanceof FunctionInvocationNodeIF) {
			return assignment;
		}
		// All assignments will end up being processed because they are
		// instances of SideEffectExpressionNodeIF. If the right hand side is
		// side effect free, just return.
		if (!hasSideEffects(assignment.rhs())) {
			return assignment;
		}
		// Otherwise, remove side effects from the right hand side. This will
		// result in a block, with the last statement in the block being the
		// side effect free form of the predicate. Merge that block into this
		// block.
		block = (BlockNodeIF) removeSideEffects(assignment.rhs());
		variables = block.variables();
		statements = block.statements();
		finalExpression = (ExpressionNodeIF) statements.child(statements
				.numChildren() - 1);
		newAssignment = factory.assignmentNode(assignment.lhs(),
				finalExpression);
		statements.setChild(statements.numChildren() - 1, newAssignment);
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	@Override
	public boolean hasSideEffects(StatementNodeIF statement) {
		boolean result = false;
		if (statement instanceof SideEffectExpressionNodeIF) {
			return true;
		} else {
			for (int i = 0; i < statement.numChildren(); i++) {
				ASTNodeIF child = statement.child(i);
				if (child instanceof StatementNodeIF) {
					result = result || hasSideEffects((StatementNodeIF) child);
				}
				if (result) {
					break;
				}
			}
		}
		return result;
	}

}