SideEffectRemover.java
package edu.udel.cis.vsl.tass.ast.impl;
import edu.udel.cis.vsl.tass.ast.IF.ASTFactoryIF;
import edu.udel.cis.vsl.tass.ast.IF.ASTNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.AbstractSyntaxTreeIF;
import edu.udel.cis.vsl.tass.ast.IF.GlobalScopeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.IdentifierNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.SequenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.SideEffectRemoverIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.FunctionDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.LocalVariableDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.declaration.VariableDeclarationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.AssignmentNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.ExpressionNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.FunctionInvocationNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.FunctionReferenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.IncrementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.OperatorNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.OperatorNodeIF.AST_OPERATOR;
import edu.udel.cis.vsl.tass.ast.IF.expression.SideEffectExpressionNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.expression.VariableReferenceNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.AssertStatementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.BlockNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.ForLoopNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.LoopNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.statement.StatementNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.BooleanTypeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.IntegerTypeNodeIF;
import edu.udel.cis.vsl.tass.ast.IF.type.TypeNodeIF;
import edu.udel.cis.vsl.tass.model.IF.SyntaxException;
/**
 * Rewrites function bodies so that statements no longer contain embedded side
 * effects. Expressions that can be rewritten are expanded into blocks of
 * simpler statements that use freshly named temporary variables
 * ({@code TASSTempVariable0}, {@code TASSTempVariable1}, ...).
 *
 * Statements whose side effects this class does not (yet) know how to remove
 * are left unchanged; they are never replaced by {@code null} or by an empty
 * block, so no code is silently dropped.
 */
public class SideEffectRemover implements SideEffectRemoverIF {

	/** Counter used to generate unique temporary variable names. */
	int tempIndex;

	/** Prefix for generated temporary variable names. */
	String tempVarPrefix = "TASSTempVariable";

	/** Factory used to build all replacement AST nodes. */
	ASTFactoryIF factory;

	public SideEffectRemover() {
		tempIndex = 0;
	}

	@Override
	public void setFactory(ASTFactoryIF factory) {
		this.factory = factory;
	}

	/**
	 * Removes side effects from the body of every function declared at global
	 * scope in {@code ast}. The factory of {@code ast} replaces any factory
	 * previously supplied through {@link #setFactory}.
	 */
	@Override
	public void transform(AbstractSyntaxTreeIF ast) throws SyntaxException {
		factory = ast.factory();
		for (int i = 0; i < ast.rootNode().globalScopeNodes().numChildren(); i++) {
			GlobalScopeNodeIF child = (GlobalScopeNodeIF) ast.rootNode()
					.globalScopeNodes().child(i);
			// Only need to worry about side effects in function bodies.
			if (child instanceof FunctionDeclarationNodeIF) {
				processBody(((FunctionDeclarationNodeIF) child).body());
			}
		}
	}

	/**
	 * Processes the body of a loop or function: every statement that has side
	 * effects is replaced in place by its side-effect-free form.
	 *
	 * @param body
	 *            the body to process; {@code null} (system and abstract
	 *            functions have no body) is ignored
	 */
	private void processBody(BlockNodeIF body) {
		SequenceNodeIF<StatementNodeIF> statements;

		// System and abstract functions don't have bodies
		if (body == null) {
			return;
		}
		statements = body.statements();
		for (int i = 0; i < statements.numChildren(); i++) {
			StatementNodeIF statement = statements.getSequenceChild(i);

			if (hasSideEffects(statement)) {
				statement = removeSideEffects(statement);
				body.statements().setChild(i, statement);
			}
		}
	}

	@Override
	public String name() {
		return "sideEffectRemover";
	}

	/**
	 * Returns a side-effect-free replacement for {@code statement}. Loops and
	 * blocks are processed in place and returned. Statement kinds this class
	 * cannot rewrite are returned unchanged, so a caller that stores the
	 * result never loses the statement.
	 */
	@Override
	public StatementNodeIF removeSideEffects(StatementNodeIF statement) {
		if (statement instanceof AssertStatementNodeIF) {
			return processAssert((AssertStatementNodeIF) statement);
		} else if (statement instanceof OperatorNodeIF) {
			return processOperator((OperatorNodeIF) statement);
		} else if (statement instanceof AssignmentNodeIF) {
			return processAssign((AssignmentNodeIF) statement);
		} else if (statement instanceof FunctionInvocationNodeIF) {
			// A standalone function invocation is fine, even though it's a
			// SideEffectExpressionIF.
			return statement;
		} else if (statement instanceof IncrementNodeIF) {
			return processIncrement((IncrementNodeIF) statement);
		} else if (statement instanceof ForLoopNodeIF) {
			return processForLoop((ForLoopNodeIF) statement);
		} else if (statement instanceof LoopNodeIF) {
			return processLoop((LoopNodeIF) statement);
		} else if (statement instanceof BlockNodeIF) {
			// Nested block: clean each of its statements in place.
			processBody((BlockNodeIF) statement);
			return statement;
		}
		// TODO: handle the remaining statement kinds (e.g. conditionals).
		// Returning the statement unchanged (rather than null) keeps it in
		// the AST instead of deleting it.
		return statement;
	}

	/**
	 * Loop nodes can contain side effects in their body, but each statement in
	 * the body must have side effects removed. A non-block body is rewritten
	 * only when it actually has side effects.
	 */
	private StatementNodeIF processLoop(LoopNodeIF statement) {
		StatementNodeIF body = statement.body();

		if (body instanceof BlockNodeIF) {
			processBody((BlockNodeIF) body);
		} else if (hasSideEffects(body)) {
			statement.setBody(removeSideEffects(body));
		}
		return statement;
	}

	/**
	 * A for loop with side effects in the initializer or incrementer needs to
	 * have them handled. Otherwise, it is like a regular loop. Absent
	 * initializers/incrementers (e.g. {@code for(;;)}) are skipped because
	 * {@link #hasSideEffects} treats {@code null} as side-effect free.
	 */
	private StatementNodeIF processForLoop(ForLoopNodeIF statement) {
		if (hasSideEffects(statement.initializer())) {
			statement
					.setInitializer(removeSideEffects(statement.initializer()));
		}
		if (hasSideEffects(statement.incrementer())) {
			statement
					.setIncrementer(removeSideEffects(statement.incrementer()));
		}
		// TODO: side effects in the condition are not removed yet.
		// Removing side effects usually returns a statement (a block), but a
		// condition can only be an expression. Maybe a condition should be an
		// expression statement once expressions are no longer statements.
		// The condition is deliberately left untouched until then: computing
		// a rewrite and discarding it only consumed temporary indices.
		return processLoop(statement);
	}

	/**
	 * This takes a pre- or post- increment or decrement, and replaces it with a
	 * block of statements. Other parts of SideEffectRemover assume that the
	 * left hand side of the last statement in a new side-effect free block is
	 * the expression needed for evaluation in the containing expression. Thus
	 * pre- and post- increments need to be handled in different ways:
	 * <ul>
	 * <li>prefix: {@code x = x +/- 1; t = x;}</li>
	 * <li>postfix: {@code t = x; x = x +/- 1; t = t;}</li>
	 * </ul>
	 * where {@code t} is a fresh integer temporary declared in the block.
	 */
	private StatementNodeIF processIncrement(IncrementNodeIF increment) {
		IdentifierNodeIF id = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		IntegerTypeNodeIF type = factory.integerTypeNode();
		BlockNodeIF block = factory.blockNode(variables, statements);
		LocalVariableDeclarationNodeIF tempVariable = factory
				.localVariableDeclarationNode(id, type, block);
		VariableReferenceNodeIF tempReference = factory.variableReferenceNode(
				tempVariable, id);
		AST_OPERATOR operator = increment.increment() ? AST_OPERATOR.ADD
				: AST_OPERATOR.SUBTRACT;
		OperatorNodeIF incrementOperation = factory.operatorNode(operator,
				increment.lhs(), factory.integerLiteralNode(null, type, 1));

		variables.addSequenceChild(tempVariable);
		if (increment.prefix()) {
			statements.addSequenceChild(factory.assignmentNode(increment.lhs(),
					incrementOperation));
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					increment.lhs()));
		} else {
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					increment.lhs()));
			statements.addSequenceChild(factory.assignmentNode(increment.lhs(),
					incrementOperation));
			// Trailing self-assignment so the last statement's lhs is the
			// saved (pre-increment) value.
			statements.addSequenceChild(factory.assignmentNode(tempReference,
					tempReference));
		}
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	/**
	 * Rewrites {@code assert(p)} as a block that stores {@code p} in a fresh
	 * boolean temporary and asserts the temporary. If {@code p} itself has
	 * side effects, the block produced for it is merged into this one.
	 */
	private StatementNodeIF processAssert(AssertStatementNodeIF statement) {
		IdentifierNodeIF id = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		BlockNodeIF block = factory.blockNode(variables, statements);
		BooleanTypeNodeIF type = (BooleanTypeNodeIF) factory.booleanTypeNode();
		LocalVariableDeclarationNodeIF tempVariable = factory
				.localVariableDeclarationNode(id, type, block);
		VariableReferenceNodeIF tempReference = factory.variableReferenceNode(
				tempVariable, id);
		StatementNodeIF assignment = factory.assignmentNode(tempReference,
				statement.predicate());
		AssertStatementNodeIF assertStatement = factory
				.assertStatementNode(tempReference);

		variables.addSequenceChild(tempVariable);
		// Remove side effects from the assignment. This may return an
		// AssignmentNodeIF or a BlockNodeIF. We need to treat these
		// differently.
		assignment = removeSideEffects(assignment);
		if (assignment instanceof BlockNodeIF) {
			BlockNodeIF innerBlock = (BlockNodeIF) assignment;

			for (int i = 0; i < innerBlock.variables().numChildren(); i++) {
				variables.addSequenceChild(innerBlock.variables()
						.getSequenceChild(i));
			}
			statements = innerBlock.statements();
		} else {
			statements.addSequenceChild(assignment);
		}
		statements.addSequenceChild(assertStatement);
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	/**
	 * Removes side effects from an operator expression. Currently handles
	 * binary {@code EQUALS} and the conditional expansion of the C
	 * {@code assert} macro; every other operator is returned unchanged
	 * (previously these were replaced by an empty block, dropping them).
	 */
	private StatementNodeIF processOperator(OperatorNodeIF operator) {
		int numChildren = operator.numChildren();

		if (numChildren == 3
				&& operator.getOperator() == AST_OPERATOR.EQUALS) {
			return processEquals(operator);
		}
		if (numChildren == 4 && isAssertExpansion(operator)) {
			// Conditional expression that is really an assert.
			StatementNodeIF assertStatement = factory
					.assertStatementNode(operator.getArgument(0));

			if (hasSideEffects(assertStatement)) {
				assertStatement = removeSideEffects(assertStatement);
			}
			return assertStatement;
		}
		// TODO: unary operators, other binary operators, and general
		// (non-assert) conditional expressions.
		return operator;
	}

	/**
	 * Rewrites {@code a == b} as a block that evaluates each operand into its
	 * own temporary and ends with {@code t0 == t1}.
	 */
	private StatementNodeIF processEquals(OperatorNodeIF operator) {
		SequenceNodeIF<VariableDeclarationNodeIF> variables = factory
				.sequenceNode(VariableDeclarationNodeIF.class);
		SequenceNodeIF<StatementNodeIF> statements = factory
				.sequenceNode(StatementNodeIF.class);
		BlockNodeIF block = factory.blockNode(variables, statements);
		IdentifierNodeIF id0 = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		IdentifierNodeIF id1 = factory.identifierNode(tempVarPrefix
				+ tempIndex++);
		// TODO: determine types by recursively examining expressions
		TypeNodeIF type = factory.integerTypeNode();
		VariableDeclarationNodeIF decl0 = factory.localVariableDeclarationNode(
				id0, type, block);
		VariableReferenceNodeIF ref0 = factory.variableReferenceNode(decl0,
				decl0.identifier());
		// NOTE(review): operands with side effects are copied as-is into the
		// temporaries' assignments; they are not rewritten recursively yet.
		AssignmentNodeIF assign0 = factory.assignmentNode(ref0,
				operator.getArgument(0));
		VariableDeclarationNodeIF decl1 = factory.localVariableDeclarationNode(
				id1, type, block);
		VariableReferenceNodeIF ref1 = factory.variableReferenceNode(decl1,
				decl1.identifier());
		AssignmentNodeIF assign1 = factory.assignmentNode(ref1,
				operator.getArgument(1));
		OperatorNodeIF newOperator = factory.operatorNode(
				operator.getOperator(), ref0, ref1);

		variables.addSequenceChild(decl0);
		variables.addSequenceChild(decl1);
		statements.addSequenceChild(assign0);
		statements.addSequenceChild(assign1);
		// Last statement is the side-effect-free expression itself.
		statements.addSequenceChild(newOperator);
		block.setVariables(variables);
		block.setStatements(statements);
		block.setSource(operator.getSource());
		return block;
	}

	/**
	 * Tests whether a conditional expression {@code c ? x : y} has a call to
	 * {@code __assert__} as its third argument, i.e. is the expansion of the
	 * C {@code assert} macro.
	 */
	private boolean isAssertExpansion(OperatorNodeIF operator) {
		ExpressionNodeIF argument2 = operator.getArgument(2);

		if (!(argument2 instanceof FunctionInvocationNodeIF)) {
			return false;
		}
		FunctionInvocationNodeIF invocation = (FunctionInvocationNodeIF) argument2;
		// TODO: Handle expressions that evaluate to function references but
		// aren't necessarily themselves function references
		if (!(invocation.function() instanceof FunctionReferenceNodeIF)) {
			return false;
		}
		FunctionReferenceNodeIF reference = (FunctionReferenceNodeIF) invocation
				.function();
		return "__assert__".equals(reference.referent().identifier().name());
	}

	/**
	 * Removes side effects from the right hand side of an assignment. The rhs
	 * is rewritten into a block; the value it produces is taken from the
	 * block's last statement, which is either an assignment (its left hand
	 * side carries the value, see processIncrement) or the value expression
	 * itself (see processEquals). Assignments whose rhs cannot be rewritten
	 * into a block are returned unchanged.
	 */
	private StatementNodeIF processAssign(AssignmentNodeIF assignment) {
		// The right hand side can be a function invocation (even though that
		// has side effects)
		if (assignment.rhs() instanceof FunctionInvocationNodeIF) {
			return assignment;
		}
		// All assignments will end up being processed because they are
		// instances of SideEffectExpressionNodeIF. If the right hand side is
		// side effect free, just return.
		if (!hasSideEffects(assignment.rhs())) {
			return assignment;
		}

		StatementNodeIF rewritten = removeSideEffects(assignment.rhs());

		if (!(rewritten instanceof BlockNodeIF)) {
			// TODO: e.g. chained assignments or assert expansions in the rhs.
			return assignment;
		}

		BlockNodeIF block = (BlockNodeIF) rewritten;
		SequenceNodeIF<VariableDeclarationNodeIF> variables = block.variables();
		SequenceNodeIF<StatementNodeIF> statements = block.statements();
		int lastIndex = statements.numChildren() - 1;

		if (lastIndex < 0) {
			// Nothing computes a value; leave the assignment alone.
			return assignment;
		}

		StatementNodeIF last = statements.getSequenceChild(lastIndex);

		if (last instanceof AssignmentNodeIF) {
			// Keep the value-producing assignment and copy its lhs, so the
			// result contains no nested assignment.
			statements.addSequenceChild(factory.assignmentNode(
					assignment.lhs(), ((AssignmentNodeIF) last).lhs()));
		} else {
			statements.setChild(lastIndex, factory.assignmentNode(
					assignment.lhs(), (ExpressionNodeIF) last));
		}
		block.setVariables(variables);
		block.setStatements(statements);
		return block;
	}

	/**
	 * Returns whether {@code statement} is, or contains (through statement
	 * children), a {@link SideEffectExpressionNodeIF}. A {@code null}
	 * statement, such as an absent for-loop initializer, has no side effects.
	 */
	@Override
	public boolean hasSideEffects(StatementNodeIF statement) {
		if (statement == null) {
			return false;
		}
		if (statement instanceof SideEffectExpressionNodeIF) {
			return true;
		}
		for (int i = 0; i < statement.numChildren(); i++) {
			ASTNodeIF child = statement.child(i);

			if (child instanceof StatementNodeIF
					&& hasSideEffects((StatementNodeIF) child)) {
				return true;
			}
		}
		return false;
	}
}