DirectingWorker.java

package dev.civl.mc.transform.common;

import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
import dev.civl.abc.ast.IF.AST;
import dev.civl.abc.ast.IF.ASTFactory;
import dev.civl.abc.ast.node.IF.ASTNode;
import dev.civl.abc.ast.node.IF.IdentifierNode;
import dev.civl.abc.ast.node.IF.PairNode;
import dev.civl.abc.ast.node.IF.SequenceNode;
import dev.civl.abc.ast.node.IF.acsl.ContractNode;
import dev.civl.abc.ast.node.IF.compound.CompoundInitializerNode;
import dev.civl.abc.ast.node.IF.compound.DesignationNode;
import dev.civl.abc.ast.node.IF.declaration.InitializerNode;
import dev.civl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode;
import dev.civl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import dev.civl.abc.ast.node.IF.expression.IntegerConstantNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode.Operator;
import dev.civl.abc.ast.node.IF.label.SwitchLabelNode;
import dev.civl.abc.ast.node.IF.statement.BlockItemNode;
import dev.civl.abc.ast.node.IF.statement.DeclarationListNode;
import dev.civl.abc.ast.node.IF.statement.ForLoopInitializerNode;
import dev.civl.abc.ast.node.IF.statement.ForLoopNode;
import dev.civl.abc.ast.node.IF.statement.IfNode;
import dev.civl.abc.ast.node.IF.statement.LabeledStatementNode;
import dev.civl.abc.ast.node.IF.statement.LoopNode;
import dev.civl.abc.ast.node.IF.statement.LoopNode.LoopKind;
import dev.civl.abc.ast.node.IF.statement.StatementNode;
import dev.civl.abc.ast.node.IF.statement.SwitchNode;
import dev.civl.abc.ast.node.IF.type.TypeNode;
import dev.civl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import dev.civl.abc.front.IF.CivlcTokenConstant;
import dev.civl.abc.front.IF.Preprocessor;
import dev.civl.abc.token.IF.Source;
import dev.civl.abc.token.IF.SyntaxException;
import dev.civl.mc.config.IF.CIVLConfiguration;
import dev.civl.mc.transform.IF.DirectingTransformer;

/**
 * This worker transforms branching statements whose file name and source line
 * number match a "direction target" as follows:
 * 
 * replace 
 *    if ( Cond ) S
 * by
 *   $assume($direct_array[$direct_index++] ? Cond : ! Cond); 
 *   if ( Cond ) S
 *
 * replace 
 *   for ( Init ; Cond ; Incr ) S
 * by
 *   Init; 
 *   while (1) { 
 *     $assume($direct_array[$direct_index++] ? Cond : ! Cond); 
 *     if ( ! Cond ) break; 
 *     S;
 *     Incr;
 *   }
 *
 * replace 
 *   while ( Cond ) S 
 * by
 *   while (1) { 
 *     $assume($direct_array[$direct_index++] ? Cond : ! Cond); 
 *     if ( ! Cond ) break; 
 *     S;
 *   }
 *
 * replace 
 *   do S while ( Cond )
 * by
 *   do { 
 *     S; 
 *     $assume($direct_array[$direct_index++] ? Cond : ! Cond); 
 *   } while ( Cond ) 
 *
 * replace
 *   switch ( E ) { case C1: ... case Cn: ... default: ... }
 * by one $assume per case/default label whose line is a direction target
 * (condition "E == Ci" for a case label, "E != C1 && ... && E != Cn" for the
 * default label), followed by the unchanged switch.
 *
 * Every $assume is preceded by
 *   $assert($direct_index < N);
 * where N is the number of directions read from the directions file.
 *
 * NOTE(review): Cond is copied into the instrumentation, so it is evaluated
 * more than once; conditions with side effects may behave differently.
 * NOTE(review): a "continue" inside a rewritten for-loop body skips Incr, and
 * inside a rewritten do-loop body skips the $assume -- confirm whether directed
 * programs can contain such loops.
 * 
 * TBD: need to make sure that we have access the right variables (e.g., via name)
 * 
 * @author dwyer
 * 
 */
public class DirectingWorker extends BaseWorker {

	/** When true, tracing output is printed for instrumented statements. */
	private final boolean debug = false;

	/** Configuration from which the path of the directions file is obtained. */
	private final CIVLConfiguration config;

	/** Name of the generated global index variable. */
	private final String indexVarName;

	/** Name of the generated global array holding the branch outcomes. */
	private final String arrayVarName;
	
	/** Source line numbers of the branch statements to instrument. */
	private final Set<Integer> directingLines;

	/** Branch outcomes (1 means "true") in the order they are consumed. */
	private final List<Integer> directions;

	/** Name of the source file to instrument, as read from the directions file. */
	private String directingFile;
	

	/**
	 * Creates a worker that builds nodes with the given factory and reads the
	 * location of the directions file from the given configuration.
	 * 
	 * @param astFactory
	 *            the factory used to create new ASTs and AST nodes
	 * @param config
	 *            the configuration providing the directions file path
	 */
	public DirectingWorker(ASTFactory astFactory,
			CIVLConfiguration config) {
		super(DirectingTransformer.LONG_NAME, astFactory);
		this.identifierPrefix = "$direct_";
		this.config = config;
		this.indexVarName = identifierPrefix + "index";
		this.arrayVarName = identifierPrefix + "array";
		this.directingLines = new HashSet<Integer>();
		this.directions = new ArrayList<Integer>();
		this.directingFile = null;
	}

	/**
	 * Reads the directions file, inserts the global direction variables and
	 * instruments every matching branch statement of the given unit.
	 * 
	 * @param unit
	 *            the AST to transform; it is released by this method
	 * @return a new AST containing the instrumented program, combined with
	 *         civlc.cvh when <code>$assume</code> was not already declared
	 * @throws SyntaxException
	 *             if building or combining the ASTs fails
	 * @throws IllegalStateException
	 *             if no directions file is configured, or it cannot be read or
	 *             is incomplete
	 */
	@Override
	protected AST transformCore(AST unit) throws SyntaxException {
		String inputFile = config.directSymEx();
		if (inputFile == null) {
			throw new IllegalStateException(
					"Expected lines and directions file for directed symbolic execution");
		}
		readDirections(inputFile);
		
		// Must include civlc.cvh to resolve $assume if it is not already present
		AST civlcAST = null;
		if (unit.getInternalOrExternalEntity("$assume") == null) {
			civlcAST = this.parseSystemLibrary(new File(
				Preprocessor.ABC_INCLUDE_PATH, "civlc.cvh"), EMPTY_MACRO_MAP);
		}

		SequenceNode<BlockItemNode> rootNode = unit.getRootNode();

		assert this.astFactory == unit.getASTFactory();
		assert this.nodeFactory == astFactory.getNodeFactory();
		// The root can only be modified and re-used once the old AST is released.
		unit.release();
		
		instrumentGlobalDefinitions(rootNode);
		instrumentBranchStatements(rootNode);

		AST instrumented = astFactory.newAST(rootNode, unit.getSourceFiles(), unit.isWholeProgram());
		
		return (civlcAST != null) ? this.combineASTs(civlcAST, instrumented) : instrumented;
	}

	/**
	 * Loads the file name, the target lines and the branch outcomes from the
	 * directions file, replacing anything read by a previous call.
	 * 
	 * File format is:
	 *    name.c              // this is the file to instrument
	 *    lines 42 23 1 ...   // these are the lines to instrument
	 *    guide 0 1 1 0 ...   // these are the branch outcomes
	 * 
	 * @param inputFile
	 *            path of the directions file
	 * @throws IllegalStateException
	 *             if the file cannot be read or one of its sections is missing
	 */
	private void readDirections(String inputFile) {
		directingLines.clear();
		directions.clear();
		directingFile = null;
		
		try (Scanner s = new Scanner(Paths.get(inputFile))) {
			if (!s.hasNextLine()) {
				throw new IllegalStateException(
						"Directions file " + inputFile + " is empty");
			}
			// Trim so that trailing blanks do not prevent the file name from matching
			directingFile = s.nextLine().trim();
			
			skipKeyword(s, "lines", inputFile);
			while (s.hasNextInt()) {
				directingLines.add(s.nextInt());
			}	
			
			skipKeyword(s, "guide", inputFile);
			while (s.hasNextInt()) {
				directions.add(s.nextInt());
			}
		} catch (IOException e) {
			// Continuing without directions would only fail later with an obscure error
			throw new IllegalStateException(
					"Unable to read directions file " + inputFile, e);
		}
	}

	/**
	 * Absorbs the token introducing a section of the directions file. The token
	 * text itself is not checked, only its presence.
	 * 
	 * @param s
	 *            scanner positioned at the start of the section
	 * @param keyword
	 *            the expected keyword, used for the error message only
	 * @param inputFile
	 *            path of the directions file, used for the error message only
	 * @throws IllegalStateException
	 *             if the scanner has no more tokens
	 */
	private void skipKeyword(Scanner s, String keyword, String inputFile) {
		if (!s.hasNext()) {
			throw new IllegalStateException("Directions file " + inputFile
					+ " is missing its \"" + keyword + "\" section");
		}
		s.next();
	}
	
	/**
	 * Instrument definitions of global variables that serve to direct symbolic execution.
	 * These definitions are:
	 * 
	 *     int $direct_index = 0;
	 *     _Bool $direct_array[] = { direction1, direction2, ... };
	 *     
	 * We insert these as the first definitions.
	 *     
	 * @param root
	 *            the root of the program being instrumented
	 * @throws SyntaxException 
	 *             if a constant node cannot be created
	 */
	private void instrumentGlobalDefinitions(SequenceNode<BlockItemNode> root) throws SyntaxException {
		List<BlockItemNode> directDecls = new ArrayList<BlockItemNode>();
				
		Source src = this.newSource(indexVarName, CivlcTokenConstant.TYPE);
		
		// int $direct_index = 0;
		IdentifierNode branchIdxId = nodeFactory.newIdentifierNode(src, indexVarName);
		IntegerConstantNode zero = nodeFactory.newIntegerConstantNode(src, "0");
		directDecls.add(nodeFactory.newVariableDeclarationNode(src, branchIdxId, basicType(BasicTypeKind.INT), zero));
		
		// _Bool $direct_array[] = { ... };  a direction of 1 is true, anything else is false
		List<PairNode<DesignationNode, InitializerNode>> initList = new LinkedList<PairNode<DesignationNode, InitializerNode>>();
		for (Integer d : directions) {
			ExpressionNode initD = nodeFactory.newBooleanConstantNode(src, d.intValue() == 1);
			initList.add(nodeFactory.newPairNode(src, null, initD));
		}
		CompoundInitializerNode branchInitializer = nodeFactory.newCompoundInitializerNode(src, initList);
		TypeNode arrayOfBool = nodeFactory.newArrayTypeNode(src, basicType(BasicTypeKind.BOOL), null);
		IdentifierNode branchArrayId = nodeFactory.newIdentifierNode(src, arrayVarName);
		directDecls.add(nodeFactory.newVariableDeclarationNode(src, branchArrayId, arrayOfBool, branchInitializer));
		
		root.insertChildren(0, directDecls);		
	}

	/**
	 * Post-order traversal of the AST that replaces each if, loop and switch
	 * statement located in the directing file on a directing line by its
	 * instrumented version.
	 * 
	 * @param node
	 *            the root of the subtree to process; may be null
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private void instrumentBranchStatements(ASTNode node) throws SyntaxException {
		if (node == null) {
			return;
		}
		
		/* Children first, so nested branches are instrumented before their parent is copied */
		for (ASTNode child : node.children()) {
			instrumentBranchStatements(child);
		}
		
		if (!(node instanceof StatementNode)) {
			return;
		}
		
		String sourceFile = node.getSource().getFirstToken().getSourceFile().getName();
		if (!directingFile.equals(sourceFile)) {
			return;
		}
		
		StatementNode replacement = null;
		
		if (node instanceof IfNode) {
			IfNode ifNode = (IfNode) node;
			int lineNum = lineOf(ifNode.getCondition());
			if (directingLines.contains(Integer.valueOf(lineNum))) {
				if (debug) System.out.println("About to instrument if at line: "+lineNum);
				replacement = instrumentedIf(ifNode);
			}
		} else if (node instanceof LoopNode) {
			LoopNode ln = (LoopNode) node;
			/* Check for the existence of a loop condition, a statement like:
			 *    for (;;)
			 * will not have directives, so we don't need to instrument it
			 */
			if (ln.getCondition() != null
					&& directingLines.contains(Integer.valueOf(lineOf(ln.getCondition())))) {
				replacement = instrumentedLoop(ln);
			}
		} else if (node instanceof SwitchNode) {
			SwitchNode switchNode = (SwitchNode) node;
			Set<Integer> caseDirectingLines = directedCaseLines(switchNode);
			if (!caseDirectingLines.isEmpty()) {
				replacement = instrumentedSwitch(switchNode, caseDirectingLines);
			}
		}
		
		if (replacement != null) {
			node.parent().setChild(node.childIndex(), replacement);
		}
	}

	/**
	 * Returns the source line of the first token of the given node.
	 */
	private int lineOf(ASTNode node) {
		return node.getSource().getFirstToken().getLine();
	}

	/**
	 * Collects the lines of the case labels of the given switch, plus the line
	 * of its default label when it has one, and keeps those that are directing
	 * lines.
	 * 
	 * @param switchNode
	 *            the switch statement to inspect
	 * @return the directing lines on which a label of the switch starts
	 */
	private Set<Integer> directedCaseLines(SwitchNode switchNode) {
		Set<Integer> caseLineNums = new HashSet<Integer>();
		
		Iterator<LabeledStatementNode> casesIter = switchNode.getCases();
		while (casesIter.hasNext()) {
			int caseLine = lineOf(casesIter.next());
			if (debug) System.out.println("Case line at: "+caseLine);
			caseLineNums.add(caseLine);
		}
		
		/* A switch need not have a default label */
		LabeledStatementNode defaultCase = switchNode.getDefaultCase();
		if (defaultCase != null) {
			int defaultLine = lineOf(defaultCase);
			if (debug) System.out.println("Default line at: "+defaultLine);
			caseLineNums.add(defaultLine);
		}
		
		/* Intersect the set of directing lines with the case statement lines */
		caseLineNums.retainAll(directingLines);
		return caseLineNums;
	}

	/**
	 * Build:  
	 *   {
	 *     $assert($direct_index < directions.size());
	 *     $assume($direct_array[$direct_index++] ? Cond : ! Cond );
	 *   }
	 * 
	 * @param src
	 *            the source attached to every created node
	 * @param cond
	 *            the branch condition; it is copied, never re-parented
	 * @return the compound statement holding the assert and the assume
	 * @throws SyntaxException
	 *             if a constant node cannot be created
	 */
	private StatementNode instrumentAssume(Source src, ExpressionNode cond) throws SyntaxException {
		/* Boolean casts not happening when the condition is "1" or "0", so forcing this with the following hack */
		String condText = cond.prettyRepresentation().toString();
		if (condText.equals("1")) {
			cond = nodeFactory.newBooleanConstantNode(src, true);
		} else if (condText.equals("0")) {
			cond = nodeFactory.newBooleanConstantNode(src, false);
		}
		
		/* $direct_array[$direct_index++] */
		ExpressionNode branchArray = nodeFactory.newIdentifierExpressionNode(src, nodeFactory.newIdentifierNode(src, arrayVarName));
		ExpressionNode branchIdx = nodeFactory.newIdentifierExpressionNode(src, nodeFactory.newIdentifierNode(src, indexVarName));
		List<ExpressionNode> accessArgs = new LinkedList<ExpressionNode>();
		accessArgs.add(branchArray);
		accessArgs.add(nodeFactory.newOperatorNode(src, Operator.POSTINCREMENT, Arrays.asList(branchIdx)));
		ExpressionNode branchAccess = nodeFactory.newOperatorNode(src, Operator.SUBSCRIPT, accessArgs);
	
		/* direction ? Cond : !Cond */
		ExpressionNode negCond = nodeFactory.newOperatorNode(src, Operator.NOT, Arrays.asList(cond.copy()));
		List<ExpressionNode> choiceArgs = new LinkedList<ExpressionNode>();
		choiceArgs.add(branchAccess);
		choiceArgs.add(cond.copy());
		choiceArgs.add(negCond);
		ExpressionNode qmarkExpr = nodeFactory.newOperatorNode(src, Operator.CONDITIONAL, choiceArgs);	
		
		IdentifierExpressionNode vAssume = nodeFactory.newIdentifierExpressionNode(src, nodeFactory.newIdentifierNode(src, "$assume"));
		StatementNode assumeStatement = nodeFactory.newExpressionStatementNode(nodeFactory.newFunctionCallNode(src, vAssume, Arrays.asList(qmarkExpr), null));
		
		/* This asserts that the branch index doesn't run past the array of given directions */
		StatementNode assertStatement = instrumentAssert(src, branchIdx);
		
		List<BlockItemNode> statements = new LinkedList<BlockItemNode>();
		statements.add(assertStatement);
		statements.add(assumeStatement);
		
		return nodeFactory.newCompoundStatementNode(src, statements);
	}
	
	/**
	 * Construct an assert statement to check for indexing beyond the array of
	 * directions: <code>$assert(branchIdx &lt; directions.size())</code>.
	 * 
	 * @param src
	 *            the source attached to every created node
	 * @param branchIdx
	 *            the index expression; it is copied, never re-parented
	 * @return the assert statement
	 * @throws SyntaxException
	 *             if the bound constant cannot be created
	 */
	private StatementNode instrumentAssert(Source src, ExpressionNode branchIdx) throws SyntaxException {
		IntegerConstantNode bound = nodeFactory.newIntegerConstantNode(src, Integer.toString(directions.size()));
		
		List<ExpressionNode> assertArgs = new LinkedList<ExpressionNode>();
		assertArgs.add(branchIdx.copy());
		assertArgs.add(bound);
		ExpressionNode ltExpr = nodeFactory.newOperatorNode(src, Operator.LT, assertArgs);
		
		IdentifierExpressionNode vAssert = nodeFactory.newIdentifierExpressionNode(src, nodeFactory.newIdentifierNode(src, "$assert"));
		
		return nodeFactory.newExpressionStatementNode(nodeFactory.newFunctionCallNode(src, vAssert, Arrays.asList(ltExpr), null));
	}
	
	/**
	 * Builds the block that replaces the given IfNode: the directing assume
	 * followed by a copy of the if statement.
	 * 
	 * @param node
	 *            the if statement to instrument
	 * @return the replacement block
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private StatementNode instrumentedIf(IfNode node) throws SyntaxException {
		List<BlockItemNode> statements = new LinkedList<BlockItemNode>();
		statements.add(instrumentAssume(node.getSource(), node.getCondition()));
		statements.add(node.copy());
		return nodeFactory.newCompoundStatementNode(node.getSource(), statements);
	}
	
	/**
	 * Builds the statement that replaces the given loop. While and for loops
	 * become a <code>while (1)</code> whose body starts with the directing
	 * assume and a conditional break; a do loop keeps its condition and gets the
	 * assume appended to its body. Loop contracts, if any, are copied onto the
	 * new loop.
	 * 
	 * @param node
	 *            the loop to instrument; its condition must not be null
	 * @return the replacement statement
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private StatementNode instrumentedLoop(LoopNode node) throws SyntaxException {
		Source src = node.getSource();
		ExpressionNode cond = node.getCondition();
		StatementNode body = node.getBody();
		SequenceNode<ContractNode> contracts = (node.loopContracts() != null) ? node.loopContracts().copy() : null;
		
		if (node.getKind() == LoopKind.WHILE) {
			List<BlockItemNode> statements = guardedLoopPrefix(src, cond);
			statements.add(body.copy());
			
			return nodeFactory.newWhileLoopNode(src,
					nodeFactory.newIntegerConstantNode(src, "1"),
					nodeFactory.newCompoundStatementNode(src, statements),
					contracts);
		}
		
		if (node.getKind() == LoopKind.DO_WHILE) {
			List<BlockItemNode> statements = new LinkedList<BlockItemNode>();
			statements.add(body.copy());
			statements.add(instrumentAssume(src, cond));
			
			return nodeFactory.newDoLoopNode(src, cond.copy(),
					nodeFactory.newCompoundStatementNode(src, statements),
					contracts);
		}
		
		/* For loop: hoist the initializer in front of the new while loop */
		ForLoopNode forLoop = (ForLoopNode) node;
		List<BlockItemNode> compoundItems = new LinkedList<BlockItemNode>();
		
		ForLoopInitializerNode init = forLoop.getInitializer();
		if (init instanceof DeclarationListNode) {
			for (VariableDeclarationNode vdn : (DeclarationListNode) init) {
				compoundItems.add(vdn.copy());
			}
		} else if (init != null) {
			// the initializer clause may be omitted, as in: for (; Cond; Incr)
			compoundItems.add(nodeFactory.newExpressionStatementNode((ExpressionNode) init.copy()));
		}

		List<BlockItemNode> statements = guardedLoopPrefix(src, cond);
		statements.add(body.copy());
		
		ExpressionNode incrementer = forLoop.getIncrementer();
		if (incrementer != null) {
			// the increment clause may be omitted, as in: for (Init; Cond; )
			statements.add(nodeFactory.newExpressionStatementNode(incrementer.copy()));
		}
		
		compoundItems.add(nodeFactory.newWhileLoopNode(src,
				nodeFactory.newIntegerConstantNode(src, "1"),
				nodeFactory.newCompoundStatementNode(src, statements),
				contracts));
		
		return nodeFactory.newCompoundStatementNode(src, compoundItems);
	}

	/**
	 * Builds the statements that open the body of a rewritten while or for
	 * loop:
	 *   $assert/$assume block for Cond;
	 *   if ( ! Cond ) break;
	 * 
	 * @param src
	 *            the source attached to every created node
	 * @param cond
	 *            the original loop condition; it is copied, never re-parented
	 * @return a new modifiable list holding the two statements
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private List<BlockItemNode> guardedLoopPrefix(Source src, ExpressionNode cond) throws SyntaxException {
		List<BlockItemNode> statements = new LinkedList<BlockItemNode>();
		statements.add(instrumentAssume(src, cond));
		
		StatementNode conditionalBreak = nodeFactory.newIfNode(src, 
				nodeFactory.newOperatorNode(src, Operator.NOT, Arrays.asList(cond.copy())),
				nodeFactory.newBreakNode(src));
		statements.add(conditionalBreak);
		
		return statements;
	}
	
	/**
	 * Builds the block that replaces the given switch: one directing assume for
	 * every case label, then for the default label if present, whose line is in
	 * the given set, followed by a copy of the switch.
	 * 
	 * @param node
	 *            the switch statement to instrument
	 * @param caseDirectingLines
	 *            the directing lines on which a label of the switch starts
	 * @return the replacement block
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private StatementNode instrumentedSwitch(SwitchNode node, Set<Integer> caseDirectingLines) throws SyntaxException {
		Source src = node.getSource();
		ExpressionNode swc = node.getCondition();
		List<BlockItemNode> statements = new LinkedList<BlockItemNode>();
		
		/* Case labels: the edge condition is "switch expression == case constant" */
		for (Iterator<LabeledStatementNode> iter = node.getCases(); iter.hasNext();) {
			LabeledStatementNode currCase = iter.next();
			
			if (caseDirectingLines.contains(lineOf(currCase))) {
				OperatorNode caseCompare = nodeFactory.newOperatorNode(src, Operator.EQUALS,
						swc.copy(), caseLabelExpression(currCase));
				statements.add(instrumentAssume(src, caseCompare));
			}
		}
		
		/* Default label, which a switch need not have */
		LabeledStatementNode defaultCase = node.getDefaultCase();
		if (defaultCase != null && caseDirectingLines.contains(lineOf(defaultCase))) {
			ExpressionNode condition = defaultCaseCondition(src, node);
			// with no case labels there is nothing to assume for the default
			if (condition != null) {
				statements.add(instrumentAssume(src, condition));
			}
		}
		
		assert (!statements.isEmpty()) : "Expected a matching case label";
		statements.add(node.copy());
		return nodeFactory.newCompoundStatementNode(node.getSource(), statements);
	}

	/**
	 * Returns a copy of the constant expression of the given case label, for
	 * use in a switch edge condition.
	 */
	private ExpressionNode caseLabelExpression(LabeledStatementNode caseStatement) throws SyntaxException {
		SwitchLabelNode sln = (SwitchLabelNode) caseStatement.getLabel();
		return sln.getExpression().copy();
	}

	/**
	 * Builds the condition under which the default label of the given switch is
	 * taken: the conjunction of the negation of all case label conditions.
	 * 
	 * @param src
	 *            the source attached to every created node
	 * @param node
	 *            the switch statement
	 * @return the condition, or null if the switch has no case labels
	 * @throws SyntaxException
	 *             if an instrumentation node cannot be created
	 */
	private ExpressionNode defaultCaseCondition(Source src, SwitchNode node) throws SyntaxException {
		ExpressionNode swc = node.getCondition();
		OperatorNode condition = null;
		
		for (Iterator<LabeledStatementNode> iter = node.getCases(); iter.hasNext();) {
			OperatorNode caseCompare = nodeFactory.newOperatorNode(src, Operator.NEQ,
					swc.copy(), caseLabelExpression(iter.next()));
			
			condition = (condition == null)
					? caseCompare
					: nodeFactory.newOperatorNode(src, Operator.LAND, condition, caseCompare);
		}
		
		return condition;
	}

}