// OpenMPOrphanWorker.java

package edu.udel.cis.vsl.civl.transform.common;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import edu.udel.cis.vsl.abc.ast.IF.AST;
import edu.udel.cis.vsl.abc.ast.IF.ASTFactory;
import edu.udel.cis.vsl.abc.ast.entity.IF.Function;
import edu.udel.cis.vsl.abc.ast.node.IF.ASTNode;
import edu.udel.cis.vsl.abc.ast.node.IF.IdentifierNode;
import edu.udel.cis.vsl.abc.ast.node.IF.SequenceNode;
import edu.udel.cis.vsl.abc.ast.node.IF.declaration.FunctionDefinitionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.ExpressionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.FunctionCallNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.omp.OmpExecutableNode;
import edu.udel.cis.vsl.abc.ast.node.IF.omp.OmpForNode;
import edu.udel.cis.vsl.abc.ast.node.IF.omp.OmpParallelNode;
import edu.udel.cis.vsl.abc.ast.node.IF.omp.OmpSyncNode;
import edu.udel.cis.vsl.abc.ast.node.IF.omp.OmpWorksharingNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.BlockItemNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.CompoundStatementNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.StatementNode;
import edu.udel.cis.vsl.abc.ast.type.IF.FunctionType;
import edu.udel.cis.vsl.abc.ast.type.IF.PointerType;
import edu.udel.cis.vsl.abc.front.IF.CivlcTokenConstant;
import edu.udel.cis.vsl.abc.token.IF.SyntaxException;
import edu.udel.cis.vsl.civl.transform.IF.OpenMPOrphanTransformer;
import edu.udel.cis.vsl.civl.util.IF.Pair;
import edu.udel.cis.vsl.civl.util.IF.Triple;

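/**
 * This worker transforms away the orphaned constructs of OpenMP programs, i.e.,
 * OpenMP for, sync and worksharing constructs that occur in a function body
 * without an enclosing parallel construct. Copies of the functions containing
 * such constructs are placed inside the parallel constructs from which those
 * functions are called, either directly or through a chain of calls.
 * 
 */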
public class OpenMPOrphanWorker extends BaseWorker {

	/**
	 * Call sites recorded by processFunction, in visiting order: the function
	 * definition containing the call, the call itself, and whether the call
	 * was reached inside an OpenMP parallel construct.
	 */
	private final ArrayList<Triple<FunctionDefinitionNode, FunctionCallNode, Boolean>> functionCalls = new ArrayList<>();

	/**
	 * Insertions collected by insertFuncs and applied at the end of transform:
	 * the body statement of a parallel construct paired with the copy of a
	 * function definition to insert into it.
	 */
	private final ArrayList<Pair<StatementNode, FunctionDefinitionNode>> nodesToInsert = new ArrayList<>();

	/**
	 * Creates a worker registered under the long name of
	 * {@link OpenMPOrphanTransformer} and sets the identifier prefix to
	 * "$omp_orphan_".
	 * 
	 * @param astFactory
	 *            the AST factory handed to the base worker
	 */
	public OpenMPOrphanWorker(ASTFactory astFactory) {
		super(OpenMPOrphanTransformer.LONG_NAME, astFactory);
		identifierPrefix = "$omp_orphan_";
	}

	/**
	 * Transforms the given AST by visiting everything reachable from the
	 * definition of main, collecting the function copies that have to be placed
	 * inside parallel constructs, inserting them, and building a new AST from
	 * the same root node.
	 * 
	 * @param ast
	 *            the AST to transform; it is released by this method
	 * @return a new AST built from the (modified) root of the given AST
	 * @throws SyntaxException
	 *             if constructing the new AST fails
	 */
	@Override
	public AST transform(AST ast) throws SyntaxException {
		SequenceNode<BlockItemNode> root = ast.getRootNode();
		AST newAst;

		ast.release();
		FunctionDefinitionNode main = ast.getMain().getDefinition();
		ArrayList<String> visitedFuncs = new ArrayList<String>();
		ompOrphan(main, null, false, visitedFuncs);

		// Each target statement gets its own running insertion index, so that
		// the copies for every parallel body are inserted at the front of that
		// body in the order in which they were collected. A single counter
		// shared by all targets would place the copies for the second and
		// later targets at arbitrary positions (possibly past the last child).
		Map<StatementNode, Integer> nextSlot = new IdentityHashMap<StatementNode, Integer>();
		for (Pair<StatementNode, FunctionDefinitionNode> insert : nodesToInsert) {
			Integer used = nextSlot.get(insert.left);
			int slot = used == null ? 0 : used.intValue();

			insertChildAt(slot, insert.left, insert.right);
			nextSlot.put(insert.left, slot + 1);
		}
		newAst = astFactory.newAST(root, ast.getSourceFiles(),
				ast.isWholeProgram());
		return newAst;
	}

	/**
	 * Recursively visits the subtree rooted at the given node. Entering a
	 * parallel construct turns on the "in parallel" flag for the subtree;
	 * entering a function definition switches the callee set to that of the
	 * function's entity; a function call is matched against the current callee
	 * set, first by name and, if no name matches, by function type, and every
	 * match is handed to processFunction.
	 * 
	 * @param node
	 *            the node to visit; may be null
	 * @param callees
	 *            the callees of the function currently being visited, or null
	 *            if none is known yet
	 * @param isInParallel
	 *            whether the node lies inside a parallel construct
	 * @param visitedFuncs
	 *            names of the functions whose definitions were already visited
	 */
	private void ompOrphan(ASTNode node, Set<Function> callees,
			boolean isInParallel, ArrayList<String> visitedFuncs) {
		if (node instanceof OmpParallelNode) {
			isInParallel = true;
		} else if (node instanceof FunctionDefinitionNode) {
			callees = ((FunctionDefinitionNode) node).getEntity().getCallees();
		} else if (node instanceof FunctionCallNode) {
			FunctionCallNode callNode = (FunctionCallNode) node;
			ExpressionNode target = callNode.getFunction();
			FunctionType targetType = null;
			String targetName = null;

			if (target instanceof IdentifierExpressionNode) {
				IdentifierNode targetId = ((IdentifierExpressionNode) target)
						.getIdentifier();

				if (!(targetId.getEntity() instanceof Function)) {
					// The identifier does not denote a function, so the call
					// goes through an expression (an identifier) whose
					// converted type is a pointer to the function type.
					PointerType pointerType = (PointerType) target
							.getConvertedType();

					targetType = (FunctionType) pointerType.referencedType();
				}
				targetName = targetId.name();
			} else {
				targetType = (FunctionType) target.getConvertedType();
			}

			if (callees != null) {
				boolean matchedByName = false;

				for (Function callee : callees) {
					if (callee.getName().equals(targetName)) {
						processFunction(callee, callNode, isInParallel,
								targetName, visitedFuncs, callees);
						matchedByName = true;
					}
				}
				if (!matchedByName) {
					// Fall back to every callee whose type matches the type
					// of the called expression.
					for (Function callee : callees) {
						if (callee.getType().equals(targetType)) {
							processFunction(callee, callNode, isInParallel,
									callee.getName(), visitedFuncs, callees);
						}
					}
				}
			}
		}

		if (node == null) {
			return;
		}
		for (ASTNode child : node.children()) {
			ompOrphan(child, callees, isInParallel, visitedFuncs);
		}
	}

	/**
	 * Handles one resolved call. If the callee has a definition, the call site
	 * is recorded in functionCalls. If the definition contains an orphaned
	 * OpenMP construct (see checkOrphan), a copy of it is scheduled for
	 * insertion into every parallel construct enclosing the call; when the call
	 * has no enclosing parallel construct, the recorded call sites are followed
	 * from callee to caller until a call made inside a parallel construct is
	 * found, and the whole chain of functions is scheduled for insertion there.
	 * Finally the callee's definition is visited once per function name.
	 * 
	 * @param call
	 *            the function being called
	 * @param node
	 *            the call expression
	 * @param isInParallel
	 *            whether the call was reached inside a parallel construct
	 * @param funcName
	 *            the name under which the callee is tracked in visitedFuncs
	 * @param visitedFuncs
	 *            names of the functions whose definitions were already visited
	 * @param callees
	 *            callee set passed on when visiting the callee's definition
	 */
	private void processFunction(Function call, FunctionCallNode node,
			boolean isInParallel, String funcName,
			ArrayList<String> visitedFuncs, Set<Function> callees) {
		FunctionDefinitionNode orphan = call.getDefinition();
		// checkOrphan(null) is false, so isOrphan implies orphan != null.
		boolean isOrphan = checkOrphan(orphan);

		if (orphan != null) {
			ASTNode parentFunc = node.parent();

			while (!(parentFunc instanceof FunctionDefinitionNode)) {
				parentFunc = parentFunc.parent();
			}

			Triple<FunctionDefinitionNode, FunctionCallNode, Boolean> callSite;

			callSite = new Triple<>((FunctionDefinitionNode) parentFunc, node,
					isInParallel);
			functionCalls.add(callSite);
		}

		if (isOrphan) {
			ArrayList<FunctionDefinitionNode> funcs = new ArrayList<FunctionDefinitionNode>();

			funcs.add(orphan);
			if (!insertFuncs(node, funcs)) {
				// No parallel construct encloses the call: climb the recorded
				// call sites from the orphan towards its callers.
				boolean foundPar = false;
				int count = 0;
				FunctionDefinitionNode currDef = orphan;
				FunctionCallNode origCall = null;

				funcs = new ArrayList<FunctionDefinitionNode>();
				while (!foundPar && count < functionCalls.size()) {
					for (Triple<FunctionDefinitionNode, FunctionCallNode, Boolean> triple : functionCalls) {
						ExpressionNode callee = triple.second.getFunction();

						// Calls made through a non-identifier expression are
						// recorded too; they carry no name to compare, so
						// they must be skipped rather than cast.
						if (callee instanceof IdentifierExpressionNode
								&& ((IdentifierExpressionNode) callee)
										.getIdentifier().name()
										.equals(currDef.getIdentifier().name())) {
							funcs.add(currDef);
							currDef = triple.first;
							count = 0;
							if (triple.third) {
								foundPar = true;
								origCall = triple.second;
							}
							break;
						}
						count++;
					}
				}
				if (foundPar) {
					insertFuncs(origCall, funcs);
				}
			}
		}

		if (orphan != null && !visitedFuncs.contains(funcName)) {
			visitedFuncs.add(funcName);
			ompOrphan(orphan, callees, false, visitedFuncs);
		}
	}

	/**
	 * Walks from the given call up to the root and, for every enclosing
	 * parallel construct, arranges for copies of the given functions to be
	 * placed in the body of that construct. If the body is not a compound
	 * statement it is wrapped, together with the copies, in a new compound
	 * statement right away; otherwise the copies are queued in nodesToInsert.
	 * After a function has been copied, the OpenMP constructs are removed from
	 * the function it was copied from.
	 * 
	 * @param node
	 *            the call at which the upward walk starts
	 * @param funcs
	 *            the function definitions to copy
	 * @return true iff at least one enclosing parallel construct was found
	 */
	private boolean insertFuncs(FunctionCallNode node,
			ArrayList<FunctionDefinitionNode> funcs) {
		boolean foundParallel = false;

		for (ASTNode ancestor = node; ancestor != null; ancestor = ancestor
				.parent()) {
			if (!(ancestor instanceof OmpParallelNode)) {
				continue;
			}
			foundParallel = true;

			StatementNode parallelBody = ((OmpParallelNode) ancestor)
					.statementNode();
			// The index has to be read before the body is detached below.
			int bodyIndex = parallelBody.childIndex();

			if (parallelBody instanceof CompoundStatementNode) {
				for (FunctionDefinitionNode func : funcs) {
					nodesToInsert
							.add(new Pair<StatementNode, FunctionDefinitionNode>(
									parallelBody, func.copy()));
					removeOmpConstruct(func);
				}
			} else {
				List<BlockItemNode> blockItems = new LinkedList<BlockItemNode>();

				parallelBody.remove();
				for (FunctionDefinitionNode func : funcs) {
					blockItems.add(func.copy());
					removeOmpConstruct(func);
				}
				blockItems.add(parallelBody);

				CompoundStatementNode wrapper = nodeFactory
						.newCompoundStatementNode(
								newSource("Orphan",
										CivlcTokenConstant.COMPOUND_STATEMENT),
								blockItems);

				ancestor.setChild(bodyIndex, wrapper);
			}
		}
		return foundParallel;
	}

	/**
	 * Determines whether the subtree rooted at the given node contains an
	 * OpenMP for, sync or worksharing node that has no parallel node among its
	 * ancestors.
	 * 
	 * @param node
	 *            the root of the subtree to inspect; may be null
	 * @return true iff such a node is found; false for a null node
	 */
	private boolean checkOrphan(ASTNode node) {
		if (node == null) {
			return false;
		}

		boolean orphaned = false;

		if (node instanceof OmpForNode || node instanceof OmpSyncNode
				|| node instanceof OmpWorksharingNode) {
			// Orphaned unless the node or one of its ancestors is a parallel
			// node.
			orphaned = true;
			for (ASTNode ancestor = node; ancestor != null; ancestor = ancestor
					.parent()) {
				if (ancestor instanceof OmpParallelNode) {
					orphaned = false;
					break;
				}
			}
		}

		// Once an orphan has been found the remaining children are no longer
		// inspected.
		for (ASTNode child : node.children()) {
			orphaned = orphaned || checkOrphan(child);
		}
		return orphaned;
	}

	/*
	 * Removes the OpenMP constructs from the subtree rooted at "node". An OMP
	 * executable node is replaced in its parent by its own statement, which is
	 * then processed recursively. A call to a function whose name starts with
	 * "omp_" is replaced as follows: omp_get_thread_num by the constant 0;
	 * omp_get_num_threads, omp_get_max_threads, omp_get_num_procs and
	 * omp_get_thread_limit by the constant 1; omp_init_lock, omp_set_lock,
	 * omp_unset_lock and omp_set_num_threads by a null statement.
	 * omp_get_wtime is left alone. Any other node is traversed recursively.
	 * 
	 * This method assumes that all of the OMP statements that are encountered
	 * can be safely removed or transformed into non-OMP equivalents. An
	 * unsupported "omp_" call trips an assertion; when assertions are disabled
	 * the call is left in place instead of being overwritten with null.
	 */
	private void removeOmpConstruct(ASTNode node) {
		if (node instanceof OmpExecutableNode) {
			// Remove "statement" node from "omp statement" node
			StatementNode stmt = ((OmpExecutableNode) node).statementNode();
			int stmtIndex = getChildIndex(node, stmt);
			assert stmtIndex != -1;
			node.removeChild(stmtIndex);

			// Link "statement" into the "omp workshare" parent
			ASTNode parent = node.parent();
			int parentIndex = getChildIndex(parent, node);
			assert parentIndex != -1;
			parent.setChild(parentIndex, stmt);

			removeOmpConstruct(stmt);

		} else if (node instanceof FunctionCallNode
				&& ((FunctionCallNode) node).getFunction() instanceof IdentifierExpressionNode
				&& ((IdentifierExpressionNode) ((FunctionCallNode) node)
						.getFunction()).getIdentifier().name()
						.startsWith("omp_")) {
			String ompFunctionName = ((IdentifierExpressionNode) ((FunctionCallNode) node)
					.getFunction()).getIdentifier().name();
			// Text of the integer constant replacing the call, if any.
			String constantText = null;
			// Whether the call is replaced by a null statement.
			boolean deleteCall = false;

			switch (ompFunctionName) {
			case "omp_get_thread_num":
				constantText = "0";
				break;
			case "omp_get_num_threads":
			case "omp_get_max_threads":
			case "omp_get_num_procs":
			case "omp_get_thread_limit":
				constantText = "1";
				break;
			case "omp_init_lock":
			case "omp_set_lock":
			case "omp_unset_lock":
			case "omp_set_num_threads":
				deleteCall = true;
				break;
			case "omp_get_wtime":
				// this will be transformed by the OMP transformer
				break;
			default:
				assert false : "Unsupported omp function call "
						+ ompFunctionName
						+ " cannot be replaced by OpenMP simplifier";
				break;
			}

			ASTNode replacement = null;

			if (constantText != null) {
				try {
					replacement = nodeFactory.newIntegerConstantNode(
							node.getSource(), constantText);
				} catch (SyntaxException e) {
					// Reported but not fatal: the call is kept as it is.
					e.printStackTrace();
				}
			} else if (deleteCall) {
				replacement = nodeFactory
						.newNullStatementNode(node.getSource());
			}

			// Link "replacement" into the omp call's parent. Without a
			// replacement the call must stay untouched; storing null here
			// would silently delete the call from the tree.
			if (replacement != null) {
				ASTNode parent = node.parent();
				int parentIndex = getChildIndex(parent, node);
				assert parentIndex != -1;
				parent.setChild(parentIndex, replacement);
			}
		} else if (node != null) {
			Iterable<ASTNode> children = node.children();
			for (ASTNode child : children) {
				removeOmpConstruct(child);
			}
		}
	}

	/*
	 * Looks for "child" (compared by reference) among the children of "node"
	 * and returns the position of its first occurrence, or -1 if "node" has no
	 * such child.
	 */
	private int getChildIndex(ASTNode node, ASTNode child) {
		int position = 0;

		while (position < node.numChildren()) {
			if (node.child(position) == child) {
				return position;
			}
			position++;
		}
		return -1;
	}

	/*
	 * Makes "nodeToInsert" child number "k" of "parent". If "k" is not below
	 * the current number of children, the node is simply set at "k".
	 * Otherwise the child previously at "k" and the children after it are each
	 * moved one position to the right; the shifting stops at the first empty
	 * (null) position, which absorbs the child being moved, or one position
	 * past the former last child.
	 */
	private void insertChildAt(int k, ASTNode parent, ASTNode nodeToInsert) {
		final int originalCount = parent.numChildren();

		if (k >= originalCount) {
			parent.setChild(k, nodeToInsert);
			return;
		}

		// Child pushed out of its position and still waiting for a new one.
		ASTNode displaced = parent.removeChild(k);
		int slot = k + 1;

		parent.setChild(k, nodeToInsert);
		while (displaced != null) {
			if (slot == originalCount) {
				// Past the former last child: append and finish.
				parent.setChild(slot, displaced);
				break;
			}

			ASTNode occupant = parent.child(slot);

			if (occupant != null) {
				parent.removeChild(slot);
			}
			parent.setChild(slot, displaced);
			// A null occupant means the position was empty, which ends the
			// shifting.
			displaced = occupant;
			slot++;
		}
	}
}