Cuda2CIVLWorker.java
package dev.civl.mc.transform.common;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import dev.civl.abc.ast.IF.AST;
import dev.civl.abc.ast.IF.ASTFactory;
import dev.civl.abc.ast.entity.IF.Function;
import dev.civl.abc.ast.node.IF.ASTNode;
import dev.civl.abc.ast.node.IF.ASTNode.NodeKind;
import dev.civl.abc.ast.node.IF.IdentifierNode;
import dev.civl.abc.ast.node.IF.SequenceNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDefinitionNode;
import dev.civl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import dev.civl.abc.ast.node.IF.expression.CastNode;
import dev.civl.abc.ast.node.IF.expression.DotNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode.ExpressionKind;
import dev.civl.abc.ast.node.IF.expression.FunctionCallNode;
import dev.civl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import dev.civl.abc.ast.node.IF.expression.IntegerConstantNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode.Operator;
import dev.civl.abc.ast.node.IF.statement.BlockItemNode;
import dev.civl.abc.ast.node.IF.statement.CompoundStatementNode;
import dev.civl.abc.ast.node.IF.statement.ExpressionStatementNode;
import dev.civl.abc.ast.node.IF.statement.LoopNode;
import dev.civl.abc.ast.node.IF.statement.StatementNode;
import dev.civl.abc.ast.node.IF.statement.StatementNode.StatementKind;
import dev.civl.abc.ast.node.IF.type.ArrayTypeNode;
import dev.civl.abc.ast.node.IF.type.FunctionTypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode.TypeNodeKind;
import dev.civl.abc.ast.type.IF.FunctionType;
import dev.civl.abc.ast.type.IF.ObjectType;
import dev.civl.abc.ast.type.IF.PointerType;
import dev.civl.abc.ast.type.IF.QualifiedObjectType;
import dev.civl.abc.ast.type.IF.StandardBasicType;
import dev.civl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import dev.civl.abc.ast.type.IF.Type;
import dev.civl.abc.ast.type.IF.Type.TypeKind;
import dev.civl.abc.ast.type.common.CommonFunctionType;
import dev.civl.abc.token.IF.CivlcToken;
import dev.civl.abc.token.IF.Source;
import dev.civl.abc.token.IF.SourceFile;
import dev.civl.abc.token.IF.SyntaxException;
import dev.civl.mc.transform.IF.Cuda2CIVLTransformer;
public class Cuda2CIVLWorker extends BaseWorker {
/**
 * Name of the header whose presence is checked by
 * {@link #transformCore(AST)}; an AST without it is returned unchanged.
 */
private static final String CUDA_HEADER = "cuda.h";
/**
 * Counter appended to the names produced by
 * {@link #newTemporaryVariableName()}; incremented on every call.
 */
private int tempVarNum;
/**
 * Creates a worker named {@link Cuda2CIVLTransformer#LONG_NAME} and sets the
 * identifier prefix to "_cuda_".
 *
 * @param astFactory the AST factory passed on to the superclass constructor
 */
public Cuda2CIVLWorker(ASTFactory astFactory) {
    super(Cuda2CIVLTransformer.LONG_NAME, astFactory);
    identifierPrefix = "_cuda_";
}
/**
 * Runs the CUDA-to-CIVL translation passes on an AST.
 *
 * An AST that does not include the CUDA header is returned as is. Otherwise
 * the AST is released, its root node is rewritten in place by the passes
 * below, and a new AST is built from the rewritten root together with the
 * source files and whole-program flag of the original AST.
 *
 * @param ast the AST to transform
 * @return ast itself when the CUDA header is absent, otherwise the newly
 *         built AST
 * @throws SyntaxException propagated from the steps invoked here
 */
@Override
protected AST transformCore(AST ast) throws SyntaxException {
    if (!hasHeader(ast, CUDA_HEADER)) {
        return ast;
    }

    SequenceNode<BlockItemNode> program = ast.getRootNode();

    // Release the old AST before rewriting its root node.
    ast.release();

    removeBuiltinDefinitions(program);
    translateCudaMallocCalls(program);
    translateKernelCalls(program);

    // Only generate a main function when one has not been generated yet.
    boolean mainAlreadyGenerated = has_gen_mainFunction(program);
    if (!mainAlreadyGenerated) {
        transformMainFunction(program);
        createNewMainFunction(program);
    }

    translateMainDefinition(program);
    translateKernelDefinitions(program);
    translateKernelDeclarations(program);

    return astFactory.newAST(program, ast.getSourceFiles(),
            ast.isWholeProgram());
}
/**
 * Produces a temporary variable name of the form
 * identifierPrefix + "tmp" + counter and advances the counter, so
 * consecutive calls yield different names.
 *
 * @return the generated temporary variable name
 */
protected String newTemporaryVariableName() {
    int suffix = tempVarNum;
    tempVarNum = suffix + 1;
    return identifierPrefix + "tmp" + suffix;
}
/**
 * Looks through the direct children of root for the first function
 * definition named "main" and hands it to
 * {@link Cuda2CIVLWorker#transformMainFunctionDefinition(FunctionDefinitionNode)}.
 * Nothing happens when no such child exists.
 *
 * @param root the root node of an Abstract Syntax Tree
 */
protected void translateMainDefinition(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null || node.nodeKind() != NodeKind.FUNCTION_DEFINITION) {
            continue;
        }
        FunctionDefinitionNode candidate = (FunctionDefinitionNode) node;
        if ("main".equals(candidate.getName())) {
            transformMainFunctionDefinition(candidate);
            // Only the first match is transformed.
            return;
        }
    }
}
/**
 * Replaces each direct child of root that is a function definition carrying
 * the global function specifier with the result of
 * {@link Cuda2CIVLWorker#kernelDefinitionTransform(FunctionDefinitionNode)}.
 *
 * @param root the root node of an Abstract Syntax Tree
 */
protected void translateKernelDefinitions(ASTNode root) {
    for (ASTNode node : root.children()) {
        boolean isDefinition = node != null
                && node.nodeKind() == NodeKind.FUNCTION_DEFINITION;
        if (!isDefinition) {
            continue;
        }
        FunctionDefinitionNode kernel = (FunctionDefinitionNode) node;
        if (kernel.hasGlobalFunctionSpecifier()) {
            // Read the slot before transforming, as the original code does.
            int slot = node.childIndex();
            root.setChild(slot, kernelDefinitionTransform(kernel));
        }
    }
}
/**
 * Replaces each direct child of root that is a function declaration carrying
 * the global function specifier and having a function type node with the
 * result of
 * {@link Cuda2CIVLWorker#kernelDeclarationTransform(FunctionDeclarationNode)}.
 *
 * @param root the root node of an Abstract Syntax Tree
 */
protected void translateKernelDeclarations(ASTNode root) {
    for (ASTNode node : root.children()) {
        boolean isDeclaration = node != null
                && node.nodeKind() == NodeKind.FUNCTION_DECLARATION;
        if (!isDeclaration) {
            continue;
        }
        FunctionDeclarationNode prototype = (FunctionDeclarationNode) node;
        boolean isKernel = prototype.hasGlobalFunctionSpecifier()
                && prototype.getTypeNode() instanceof FunctionTypeNode;
        if (isKernel) {
            // Read the slot before transforming, as the original code does.
            int slot = node.childIndex();
            root.setChild(slot, kernelDeclarationTransform(prototype));
        }
    }
}
/**
 * Walks the tree below root and replaces every function call whose callee is
 * the identifier "cudaMalloc" with the result of
 * {@link Cuda2CIVLWorker#cudaMallocTransform(FunctionCallNode)}.
 * A replaced call is not searched any further; every other non-null child
 * is searched recursively.
 *
 * @param root the root node of an Abstract Syntax Tree
 */
protected void translateCudaMallocCalls(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null) {
            continue;
        }
        if (node.nodeKind() == NodeKind.EXPRESSION
                && ((ExpressionNode) node)
                        .expressionKind() == ExpressionKind.FUNCTION_CALL) {
            FunctionCallNode call = (FunctionCallNode) node;
            ExpressionNode callee = call.getFunction();
            if (callee
                    .expressionKind() == ExpressionKind.IDENTIFIER_EXPRESSION) {
                String calleeName = ((IdentifierExpressionNode) callee)
                        .getIdentifier().name();
                if (calleeName.equals("cudaMalloc")) {
                    int slot = call.childIndex();
                    root.setChild(slot, cudaMallocTransform(call));
                    // The replacement is not descended into.
                    continue;
                }
            }
        }
        translateCudaMallocCalls(node);
    }
}
/**
 * Walks the tree below root and replaces every expression statement whose
 * expression is a function call with at least one context argument by the
 * result of {@link Cuda2CIVLWorker#kernelCallTransform(FunctionCallNode)}.
 * A replaced statement is not searched any further; every other non-null
 * child is searched recursively.
 *
 * @param root the root node of an Abstract Syntax Tree
 */
protected void translateKernelCalls(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null) {
            continue;
        }
        if (node.nodeKind() == NodeKind.STATEMENT
                && ((StatementNode) node)
                        .statementKind() == StatementKind.EXPRESSION) {
            ExpressionNode inner = ((ExpressionStatementNode) node)
                    .getExpression();
            if (inner.expressionKind() == ExpressionKind.FUNCTION_CALL) {
                FunctionCallNode call = (FunctionCallNode) inner;
                if (call.getNumberOfContextArguments() > 0) {
                    // The whole statement, not just the call, is replaced.
                    root.setChild(node.childIndex(),
                            kernelCallTransform(call));
                    continue;
                }
            }
        }
        translateKernelCalls(node);
    }
}
/**
 * Rewrites "cudaMalloc( (void**) ptrPtr, size)" as the comma expression
 * "*ptrPtr = (type)$malloc(here, size), cudaSuccess".
 *
 * "ptrPtr" is the first argument with all enclosing casts removed, "here" is
 * the node produced by nodeFactory.newHereNode, and "type" is the type
 * referenced by the initial type of ptrPtr when that type is a pointer type,
 * or the initial type of ptrPtr itself otherwise. The argument expressions
 * are copied, not moved, into the new expression.
 *
 * @param cudaMallocCall a FunctionCallNode which is a call to cuda malloc
 * @return The translated cuda malloc call as an expression node
 */
protected ExpressionNode cudaMallocTransform(
        FunctionCallNode cudaMallocCall) {
    Source source = cudaMallocCall.getSource();

    // Peel every cast off the first argument to reach the pointer expression.
    ExpressionNode target = cudaMallocCall.getArgument(0);
    while (target instanceof CastNode) {
        target = ((CastNode) target).getArgument();
    }
    ExpressionNode byteCount = cudaMallocCall.getArgument(1);

    // Left-hand side: *target
    ExpressionNode dereference = nodeFactory.newOperatorNode(source,
            Operator.DEREFERENCE, Arrays.asList(target.copy()));

    // Type used for the cast placed in front of the $malloc call.
    Type targetType = target.getInitialType();
    Type castType = targetType.kind() == TypeKind.POINTER
            ? ((PointerType) targetType).referencedType()
            : targetType;

    // Right-hand side: (castType) $malloc(here, byteCount)
    FunctionCallNode allocation = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$malloc"),
            Arrays.asList(nodeFactory.newHereNode(source), byteCount.copy()),
            null);
    CastNode castAllocation = nodeFactory.newCastNode(source,
            typeNode(source, castType), allocation);

    OperatorNode store = nodeFactory.newOperatorNode(source, Operator.ASSIGN,
            Arrays.asList(dereference, castAllocation));

    // The comma expression makes the whole rewrite evaluate to cudaSuccess.
    return nodeFactory.newOperatorNode(source, Operator.COMMA,
            Arrays.asList(store,
                    nodeFactory.newEnumerationConstantNode(nodeFactory
                            .newIdentifierNode(source, "cudaSuccess"))));
}
/**
 * Adds a call to $cuda_init as the first item of the body of main and a call
 * to $cuda_finalize as its last item, then installs the resulting body on
 * the function.
 *
 * @param mainFunction the function definition node for the main function
 */
private void transformMainFunctionDefinition(
        FunctionDefinitionNode mainFunction) {
    Source source = mainFunction.getSource();
    FunctionCallNode initCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_init"),
            Collections.<ExpressionNode>emptyList(), null);
    FunctionCallNode finalizeCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_finalize"),
            Collections.<ExpressionNode>emptyList(), null);

    // Position 0 puts the init call in front of everything else.
    CompoundStatementNode withInit = insertToCompoundStatement(
            mainFunction.getBody(),
            nodeFactory.newExpressionStatementNode(initCall), 0);

    // The end position is measured on the body that already holds the init
    // call.
    CompoundStatementNode withFinalize = insertToCompoundStatement(withInit,
            nodeFactory.newExpressionStatementNode(finalizeCall),
            withInit.numChildren());

    mainFunction.setBody(withFinalize);
}
/**
 * Maps the name of an original kernel to the name used for its transformed
 * version by putting "_cuda_" in front of it.
 *
 * @param name a string that is the name of the original kernel
 * @return the transformed name
 */
private String transformKernelName(String name) {
    StringBuilder transformed = new StringBuilder("_cuda_");
    transformed.append(name);
    return transformed.toString();
}
/**
 * Builds the transformed definition of a kernel.
 *
 * The new function is named by
 * {@link Cuda2CIVLWorker#transformKernelName(String)}. Its formals are
 * gridDim and blockDim (dim3), _cuda_mem_size (size_t) and _cuda_stream
 * (cudaStream_t), followed by copies of the original parameters whose type
 * node is not void. Its return type is a copy of the original one. Its body
 * holds the inner kernel definition built from the original body (see
 * {@link Cuda2CIVLWorker#buildInnerKernelDefinition(CompoundStatementNode)})
 * followed by the statement
 * "$cuda_enqueue_kernel(_cuda_stream, _cuda_kernel, gridDim, blockDim)".
 *
 * @param oldDefinition a FunctionDefinitionNode which is the definition of the original kernel
 * @return the transformed kernel definition
 */
protected FunctionDefinitionNode kernelDefinitionTransform(
        FunctionDefinitionNode oldDefinition) {
    // TODO: add execution configuration parameters as regular parameters
    Source source = oldDefinition.getSource();

    // The inner kernel is built first, from the original body.
    FunctionDefinitionNode innerKernel = buildInnerKernelDefinition(
            oldDefinition.getBody());
    String renamed = transformKernelName(
            oldDefinition.getIdentifier().name());

    // Body: inner kernel definition, then the enqueue call.
    List<ExpressionNode> enqueueArguments = Arrays.asList(
            identifierExpression(source, "_cuda_stream"),
            identifierExpression(source, "_cuda_kernel"),
            identifierExpression(source, "gridDim"),
            identifierExpression(source, "blockDim"));
    FunctionCallNode enqueueCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_enqueue_kernel"),
            enqueueArguments, null);
    CompoundStatementNode kernelBody = nodeFactory.newCompoundStatementNode(
            source, Arrays.asList(innerKernel,
                    nodeFactory.newExpressionStatementNode(enqueueCall)));

    // Context formals come first: { name, typedef name } pairs.
    String[][] contextFormals = { { "gridDim", "dim3" },
            { "blockDim", "dim3" }, { "_cuda_mem_size", "size_t" },
            { "_cuda_stream", "cudaStream_t" } };
    List<VariableDeclarationNode> formals = new ArrayList<>();
    for (String[] nameAndType : contextFormals) {
        formals.add(nodeFactory.newVariableDeclarationNode(source,
                identifier(nameAndType[0]), nodeFactory.newTypedefNameNode(
                        identifier(nameAndType[1]), null)));
    }

    // Then the original parameters, skipping any whose type node is void.
    for (VariableDeclarationNode parameter : oldDefinition.getTypeNode()
            .getParameters()) {
        if (parameter.getTypeNode().kind() == TypeNodeKind.VOID) {
            continue;
        }
        formals.add(parameter.copy());
    }

    SequenceNode<VariableDeclarationNode> formalSequence = nodeFactory
            .newSequenceNode(source, "kernel formals", formals);
    FunctionTypeNode signature = nodeFactory.newFunctionTypeNode(source,
            oldDefinition.getTypeNode().getReturnType().copy(),
            formalSequence, true);
    return nodeFactory.newFunctionDefinitionNode(source,
            nodeFactory.newIdentifierNode(source, renamed), signature, null,
            kernelBody);
}
/**
 * Builds the transformed declaration of a kernel.
 *
 * The new declaration is named by
 * {@link Cuda2CIVLWorker#transformKernelName(String)}. Its formals are
 * gridDim and blockDim (dim3), _cuda_mem_size (size_t) and _cuda_stream
 * (cudaStream_t), followed by copies of the original parameters whose type
 * node is not void. Its return type is a copy of the original one. The type
 * node of oldDeclaration is cast to FunctionTypeNode without a check.
 *
 * @param oldDeclaration a FunctionDeclarationNode which is the declaration of the original kernel
 * @return the transformed kernel declaration node
 */
protected FunctionDeclarationNode kernelDeclarationTransform(
        FunctionDeclarationNode oldDeclaration) {
    Source source = oldDeclaration.getSource();
    String renamed = transformKernelName(
            oldDeclaration.getIdentifier().name());

    // Context formals come first: { name, typedef name } pairs.
    String[][] contextFormals = { { "gridDim", "dim3" },
            { "blockDim", "dim3" }, { "_cuda_mem_size", "size_t" },
            { "_cuda_stream", "cudaStream_t" } };
    List<VariableDeclarationNode> formals = new ArrayList<>();
    for (String[] nameAndType : contextFormals) {
        formals.add(nodeFactory.newVariableDeclarationNode(source,
                identifier(nameAndType[0]), nodeFactory.newTypedefNameNode(
                        identifier(nameAndType[1]), null)));
    }

    // Then the original parameters, skipping any whose type node is void.
    FunctionTypeNode originalType = (FunctionTypeNode) oldDeclaration
            .getTypeNode();
    for (VariableDeclarationNode parameter : originalType.getParameters()) {
        if (parameter.getTypeNode().kind() == TypeNodeKind.VOID) {
            continue;
        }
        formals.add(parameter.copy());
    }

    SequenceNode<VariableDeclarationNode> formalSequence = nodeFactory
            .newSequenceNode(source, "kernel formals", formals);
    FunctionTypeNode signature = nodeFactory.newFunctionTypeNode(source,
            originalType.getReturnType().copy(), formalSequence, true);
    return nodeFactory.newFunctionDeclarationNode(source,
            nodeFactory.newIdentifierNode(source, renamed), signature, null);
}
/**
 * Removes from a body of code every direct child that is a variable
 * declaration with shared storage and returns copies of the removed
 * declarations with the shared storage flag cleared. Declarations nested
 * deeper inside the body are not visited.
 *
 * @param statements a CompoundStatementNode that is any section of code
 * @return The list of copies of the removed variable declarations
 */
protected List<VariableDeclarationNode> extractSharedVariableDeclarations(
        CompoundStatementNode statements) {
    List<VariableDeclarationNode> extracted = new ArrayList<>();
    for (BlockItemNode item : statements) {
        if (!(item instanceof VariableDeclarationNode)) {
            continue;
        }
        // The storage check and the flag change are done on a copy; the
        // original node is only removed from its parent.
        VariableDeclarationNode duplicate = (VariableDeclarationNode) item
                .copy();
        if (!duplicate.hasSharedStorage()) {
            continue;
        }
        statements.removeChild(item.childIndex());
        duplicate.setSharedStorage(false);
        extracted.add(duplicate);
    }
    return extracted;
}
/**
 * Builds the inner kernel function "_cuda_kernel" from the body of an
 * original kernel definition.
 *
 * The function returns void and takes the formals "_cuda_this" (pointer to
 * $cuda_kernel_instance_t) and "_cuda_event" (cudaEvent_t). Its body holds,
 * in order: the block function (see
 * {@link Cuda2CIVLWorker#buildBlockDefinition(CompoundStatementNode)}),
 * "$cuda_wait_in_queue(_cuda_this, _cuda_event)",
 * "$cuda_run_procs(gridDim, _cuda_block)" and
 * "$cuda_kernel_finish(_cuda_this)".
 *
 * @param body a CompoundStatementNode which is the body of the original kernel
 * @return The completed inner kernel definition
 */
protected FunctionDefinitionNode buildInnerKernelDefinition(
        CompoundStatementNode body) {
    Source source = body.getSource();

    // Formals: $cuda_kernel_instance_t *_cuda_this, cudaEvent_t _cuda_event
    VariableDeclarationNode instanceFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_this"),
                    nodeFactory.newPointerTypeNode(source,
                            nodeFactory.newTypedefNameNode(
                                    nodeFactory.newIdentifierNode(source,
                                            "$cuda_kernel_instance_t"),
                                    null)));
    VariableDeclarationNode eventFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_event"),
                    nodeFactory.newTypedefNameNode(nodeFactory
                            .newIdentifierNode(source, "cudaEvent_t"), null));
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "innerKernelFormals",
                    Arrays.asList(instanceFormal, eventFormal));

    // The block function is built from the original kernel body.
    FunctionDefinitionNode blockFunction = buildBlockDefinition(body);

    // $cuda_wait_in_queue(_cuda_this, _cuda_event)
    FunctionCallNode waitCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_wait_in_queue"),
            Arrays.asList(identifierExpression(source, "_cuda_this"),
                    identifierExpression(source, "_cuda_event")),
            null);
    // $cuda_run_procs(gridDim, _cuda_block)
    FunctionCallNode runBlocksCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_run_procs"),
            Arrays.asList(identifierExpression(source, "gridDim"),
                    identifierExpression(source, "_cuda_block")),
            null);
    // $cuda_kernel_finish(_cuda_this)
    FunctionCallNode finishCall = nodeFactory.newFunctionCallNode(source,
            identifierExpression(source, "$cuda_kernel_finish"),
            Arrays.asList(identifierExpression(source, "_cuda_this")), null);

    CompoundStatementNode kernelBody = nodeFactory.newCompoundStatementNode(
            source,
            Arrays.asList(blockFunction,
                    nodeFactory.newExpressionStatementNode(waitCall),
                    nodeFactory.newExpressionStatementNode(runBlocksCall),
                    nodeFactory.newExpressionStatementNode(finishCall)));

    FunctionTypeNode signature = nodeFactory.newFunctionTypeNode(source,
            nodeFactory.newVoidTypeNode(source), formals, false);
    return nodeFactory.newFunctionDefinitionNode(source,
            nodeFactory.newIdentifierNode(source, "_cuda_kernel"), signature,
            null, kernelBody);
}
/**
 * Given the body of a kernel definition, this method builds the
 * block function "_cuda_block" within the inner kernel, which sets up the
 * per-block objects and then starts the thread function through
 * $cuda_run_procs.
 *
 * The generated function returns void, takes one formal "blockIdx" of type
 * uint3, and its body holds, in this order ("here" stands for a node
 * produced by nodeFactory.newHereNode):
 * <ol>
 * <li>int numThreads = blockDim.x * blockDim.y * blockDim.z</li>
 * <li>int numWarps = numThreads / 32 + (numThreads % 32 != 0)</li>
 * <li>$gcomm gComm = $gcomm_create(here, numThreads)</li>
 * <li>$gbarrier warpBarriers[numWarps]</li>
 * <li>$scope _block_root = here</li>
 * <li>a for loop over i from 0 while i &lt; numWarps - 1 that assigns
 * warpBarriers[i] = $gbarrier_create(_block_root, 32)</li>
 * <li>warpBarriers[numWarps - 1] = $gbarrier_create(_block_root,
 * numThreads - (numWarps - 1) * 32)</li>
 * <li>$gbarrier _cuda_block_barrier = $gbarrier_create(here, numThreads)</li>
 * <li>the declarations returned by
 * {@link Cuda2CIVLWorker#extractSharedVariableDeclarations(CompoundStatementNode)},
 * after {@link Cuda2CIVLWorker#completeSharedExternArrays(List)}</li>
 * <li>the thread function (see
 * {@link Cuda2CIVLWorker#buildThreadDefinition(CompoundStatementNode)})</li>
 * <li>$cuda_run_procs(blockDim, _cuda_thread)</li>
 * <li>$gbarrier_destroy(_cuda_block_barrier)</li>
 * <li>a for loop over i from 0 while i &lt; numWarps that calls
 * $gbarrier_destroy(warpBarriers[i])</li>
 * <li>$gcomm_destroy(gComm, (void*) 0)</li>
 * </ol>
 *
 * The shared variable declarations are removed from body before body is
 * passed on to buildThreadDefinition, so body is modified by this method.
 *
 * @param body a CompoundStatementNode which is the body of the original kernel
 * @return The completed block function definition
 */
protected FunctionDefinitionNode buildBlockDefinition(
CompoundStatementNode body) {
Source source = body.getSource();
// numThreads = blockDim.x * blockDim.y * blockDim.z
DotNode blockDimX = nodeFactory.newDotNode(source,
this.identifierExpression(source, "blockDim"),
nodeFactory.newIdentifierNode(source, "x"));
DotNode blockDimY = nodeFactory.newDotNode(source,
this.identifierExpression(source, "blockDim"),
nodeFactory.newIdentifierNode(source, "y"));
DotNode blockDimZ = nodeFactory.newDotNode(source,
this.identifierExpression(source, "blockDim"),
nodeFactory.newIdentifierNode(source, "z"));
OperatorNode totalThreads = nodeFactory
.newOperatorNode(source, Operator.TIMES,
Arrays.asList(
nodeFactory
.newOperatorNode(source, Operator.TIMES,
Arrays.<ExpressionNode>asList(
blockDimX, blockDimY)),
blockDimZ));
VariableDeclarationNode numThreads = nodeFactory.newVariableDeclarationNode(source,
this.identifier("numThreads"), nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT));
numThreads.setInitializer(totalThreads);
// Initializer of _cuda_block_barrier (declared further down):
// $gbarrier_create(here, numThreads)
FunctionCallNode newGbarrier = nodeFactory.newFunctionCallNode(source,
this.identifierExpression(source, "$gbarrier_create"),
Arrays.asList(nodeFactory.newHereNode(source), this.identifierExpression("numThreads")),
null);
// $gcomm gComm = $gcomm_create(here, numThreads)
VariableDeclarationNode gCommCreate = nodeFactory.newVariableDeclarationNode(source,
this.identifier("gComm"), nodeFactory.newTypedefNameNode(this.identifier("$gcomm"), null));
List<ExpressionNode> createArguments = new ArrayList<ExpressionNode>();
createArguments.add(nodeFactory.newHereNode(source));
createArguments.add(this.identifierExpression("numThreads"));
gCommCreate.setInitializer(nodeFactory.newFunctionCallNode(source, this.identifierExpression("$gcomm_create"),
createArguments, null));
// int numWarps = numThreads / 32 + (numThreads % 32 != 0)
// i.e. one extra warp when numThreads is not a multiple of 32
VariableDeclarationNode numWarps = nodeFactory.newVariableDeclarationNode(source,
this.identifier("numWarps"), nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT));
OperatorNode totalWarps = nodeFactory.newOperatorNode(source, Operator.PLUS,
Arrays.asList(
nodeFactory.newOperatorNode(source, Operator.DIV, Arrays.asList(
this.identifierExpression("numThreads"), nodeFactory.newIntConstantNode(source, 32))),
nodeFactory.newOperatorNode(source, Operator.NEQ, Arrays.asList(
nodeFactory.newOperatorNode(source, Operator.MOD, Arrays.asList(
this.identifierExpression("numThreads"), nodeFactory.newIntConstantNode(source, 32)
)),
nodeFactory.newIntConstantNode(source, 0)))));
numWarps.setInitializer(totalWarps);
// $gbarrier warpBarriers[numWarps]
VariableDeclarationNode warpBarrierArray = nodeFactory.newVariableDeclarationNode(source,
this.identifier("warpBarriers"), nodeFactory.newArrayTypeNode(source,
nodeFactory.newTypedefNameNode(this.identifier("$gbarrier"), null), this.identifierExpression("numWarps")));
// $scope _block_root = here
VariableDeclarationNode blockScope = nodeFactory.newVariableDeclarationNode(source,
this.identifier("_block_root"), nodeFactory.newTypedefNameNode(this.identifier("$scope"), null));
blockScope.setInitializer(nodeFactory.newHereNode(source));
// Creation loop:
// for (int i = 0; i < numWarps - 1; i++)
//   warpBarriers[i] = $gbarrier_create(_block_root, 32);
List<VariableDeclarationNode> creationLoopInitializerList = new ArrayList<VariableDeclarationNode>();
VariableDeclarationNode forLoopInitializer = nodeFactory.newVariableDeclarationNode(source,
this.identifier("i"), nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT));
forLoopInitializer.setInitializer(nodeFactory.newIntConstantNode(source, 0));
creationLoopInitializerList.add(forLoopInitializer);
List<BlockItemNode> creationLoopItems = new ArrayList<BlockItemNode>();
FunctionCallNode warpBarrierCreate = nodeFactory.newFunctionCallNode(source,
this.identifierExpression("$gbarrier_create"),
Arrays.asList(this.identifierExpression("_block_root"), nodeFactory.newIntConstantNode(source, 32)), null);
OperatorNode warpBarrierAssign = nodeFactory.newOperatorNode(source, Operator.ASSIGN,
Arrays.asList(
nodeFactory.newOperatorNode(source, Operator.SUBSCRIPT,
Arrays.asList(this.identifierExpression("warpBarriers"),
this.identifierExpression("i"))),
warpBarrierCreate));
creationLoopItems.add(nodeFactory.newExpressionStatementNode(warpBarrierAssign));
CompoundStatementNode creationLoopBody = nodeFactory.newCompoundStatementNode(source, creationLoopItems);
// Template for "numWarps - 1"; only copies of it are placed into the tree.
OperatorNode numWarpsMinusOne = nodeFactory.newOperatorNode(source, Operator.MINUS, Arrays.asList(this.identifierExpression("numWarps"),
nodeFactory.newIntConstantNode(source, 1)));
LoopNode warpBarrierLoop = nodeFactory.newForLoopNode(source,
nodeFactory.newForLoopInitializerNode(source, creationLoopInitializerList),
nodeFactory.newOperatorNode(source, Operator.LT, Arrays.asList(this.identifierExpression("i"),
numWarpsMinusOne.copy())),
nodeFactory.newOperatorNode(source, Operator.POSTINCREMENT, this.identifierExpression("i")),
creationLoopBody, null);
// The last warp gets the remaining threads:
// warpBarriers[numWarps - 1] =
//   $gbarrier_create(_block_root, numThreads - (numWarps - 1) * 32);
OperatorNode lastWarpNumThreads = nodeFactory.newOperatorNode(source, Operator.MINUS, Arrays.asList(
this.identifierExpression("numThreads"),
nodeFactory.newOperatorNode(source, Operator.TIMES, Arrays.asList(
numWarpsMinusOne.copy(),
nodeFactory.newIntConstantNode(source, 32)))));
FunctionCallNode lastWarpBarrierCreate = nodeFactory.newFunctionCallNode(source,
this.identifierExpression("$gbarrier_create"),
Arrays.asList(this.identifierExpression("_block_root"), lastWarpNumThreads), null);
OperatorNode LastWarpAssign = nodeFactory.newOperatorNode(source, Operator.ASSIGN,Arrays.asList(
nodeFactory.newOperatorNode(source, Operator.SUBSCRIPT, Arrays.asList(
this.identifierExpression("warpBarriers"),
numWarpsMinusOne.copy())),
lastWarpBarrierCreate));
// $gbarrier _cuda_block_barrier = $gbarrier_create(here, numThreads)
VariableDeclarationNode gbarrierCreation = nodeFactory
.newVariableDeclarationNode(source,
nodeFactory.newIdentifierNode(source,
"_cuda_block_barrier"),
nodeFactory.newTypedefNameNode(nodeFactory
.newIdentifierNode(source, "$gbarrier"), null),
newGbarrier);
// Pull the shared declarations out of the kernel body (this modifies body)
// and complete extern arrays without extent; this happens before the thread
// function is built from the same body.
List<VariableDeclarationNode> sharedVars = this
.extractSharedVariableDeclarations(body);
completeSharedExternArrays(sharedVars);
// Formal of the block function: uint3 blockIdx
SequenceNode<VariableDeclarationNode> blockFormals = nodeFactory
.newSequenceNode(source, "blockFormals",
Arrays.asList(
nodeFactory.newVariableDeclarationNode(source,
nodeFactory.newIdentifierNode(source,
"blockIdx"),
nodeFactory.newTypedefNameNode(
nodeFactory.newIdentifierNode(
source, "uint3"),
null))));
FunctionDefinitionNode threadDefinition = this
.buildThreadDefinition(body);
// $cuda_run_procs(blockDim, _cuda_thread)
FunctionCallNode runProcsCall = nodeFactory.newFunctionCallNode(source,
this.identifierExpression(source, "$cuda_run_procs"),
Arrays.asList(this.identifierExpression(source, "blockDim"),
this.identifierExpression(source, "_cuda_thread")),
null);
// Destruction loop (reuses a copy of the "int i = 0" initializer):
// for (int i = 0; i < numWarps; i++) $gbarrier_destroy(warpBarriers[i]);
List<VariableDeclarationNode> destructionLoopInitializerList = new ArrayList<VariableDeclarationNode>();
destructionLoopInitializerList.add(forLoopInitializer.copy());
List<BlockItemNode> destructionLoopItems = new ArrayList<BlockItemNode>();
FunctionCallNode warpBarrierDestroy = nodeFactory.newFunctionCallNode(source,
this.identifierExpression("$gbarrier_destroy"),
Arrays.asList(nodeFactory.newOperatorNode(source, Operator.SUBSCRIPT,
Arrays.asList(this.identifierExpression("warpBarriers"), this.identifierExpression("i")))), null);
destructionLoopItems.add(nodeFactory.newExpressionStatementNode(warpBarrierDestroy));
CompoundStatementNode destructionLoopBody = nodeFactory.newCompoundStatementNode(source, destructionLoopItems);
LoopNode warpBarrierDestruction = nodeFactory.newForLoopNode(source,
nodeFactory.newForLoopInitializerNode(source, destructionLoopInitializerList),
nodeFactory.newOperatorNode(source, Operator.LT, Arrays.asList(this.identifierExpression("i"), this.identifierExpression("numWarps"))),
nodeFactory.newOperatorNode(source, Operator.POSTINCREMENT, this.identifierExpression("i")),
destructionLoopBody, null);
// $gbarrier_destroy(_cuda_block_barrier)
FunctionCallNode gbarrierDestruction = nodeFactory.newFunctionCallNode(
source, this.identifierExpression(source, "$gbarrier_destroy"),
Arrays.asList(this.identifierExpression(source,
"_cuda_block_barrier")),
null);
// $gcomm_destroy(gComm, (void*) 0)
List<ExpressionNode> destroyArguments = new ArrayList<ExpressionNode>();
destroyArguments.add(this.identifierExpression("gComm"));
IntegerConstantNode zero = nodeFactory.newIntConstantNode(source, 0);
CastNode NULL = nodeFactory.newCastNode(source, nodeFactory.newPointerTypeNode(source, nodeFactory.newVoidTypeNode(source)), zero);
destroyArguments.add(NULL);
FunctionCallNode gCommDestroy = nodeFactory.newFunctionCallNode(source, this.identifierExpression("$gcomm_destroy"),
destroyArguments, null);
// Assemble the body; the order of these additions is the order of the items
// in the generated function. Note that numWarps is added before gComm.
List<BlockItemNode> blockBodyItems = new ArrayList<BlockItemNode>();
blockBodyItems.add(numThreads);
blockBodyItems.add(numWarps);
blockBodyItems.add(gCommCreate);
blockBodyItems.add(warpBarrierArray);
blockBodyItems.add(blockScope);
blockBodyItems.add(warpBarrierLoop);
blockBodyItems.add(nodeFactory.newExpressionStatementNode(LastWarpAssign));
blockBodyItems.add(gbarrierCreation);
blockBodyItems.addAll(sharedVars);
blockBodyItems.add(threadDefinition);
blockBodyItems
.add(nodeFactory.newExpressionStatementNode(runProcsCall));
blockBodyItems.add(
nodeFactory.newExpressionStatementNode(gbarrierDestruction));
blockBodyItems.add(warpBarrierDestruction);
blockBodyItems.add(nodeFactory.newExpressionStatementNode(gCommDestroy));
CompoundStatementNode blockBody = nodeFactory
.newCompoundStatementNode(source, blockBodyItems);
// void _cuda_block(uint3 blockIdx) { ... }
FunctionDefinitionNode blockDefinition = nodeFactory
.newFunctionDefinitionNode(source,
nodeFactory.newIdentifierNode(source, "_cuda_block"),
nodeFactory.newFunctionTypeNode(source,
nodeFactory.newVoidTypeNode(source),
blockFormals, false),
null, blockBody);
return blockDefinition;
}
/**
 * For every declaration in the list that has extern storage and an array
 * type node without an extent, sets the extent to the identifier expression
 * "_cuda_mem_size" and clears the extern storage flag. All other
 * declarations are left untouched.
 *
 * @param sharedVars a list of VariableDeclarationNodes
 */
protected void completeSharedExternArrays(
        List<VariableDeclarationNode> sharedVars) {
    for (VariableDeclarationNode declaration : sharedVars) {
        if (!declaration.hasExternStorage()) {
            continue;
        }
        TypeNode declaredType = declaration.getTypeNode();
        if (declaredType.kind() != TypeNodeKind.ARRAY) {
            continue;
        }
        ArrayTypeNode arrayType = (ArrayTypeNode) declaredType;
        if (arrayType.getExtent() != null) {
            // An explicit extent is kept, and so is the extern flag.
            continue;
        }
        arrayType.setExtent(identifierExpression("_cuda_mem_size"));
        declaration.setExternStorage(false);
    }
}
/**
 * Builds the definition of the per-thread function
 * <code>_cuda_thread</code> from the body of an original kernel.
 *
 * <p>
 * The given body is first processed by
 * {@link #translateShflCalls(ASTNode)} and then rebuilt by
 * {@link #translateAtomicCalls(ASTNode)}. Copies of the resulting block
 * items are placed between a prologue and an epilogue, giving a function
 * body with the following layout:
 * </p>
 * <ol>
 * <li><code>$local_start()</code></li>
 * <li>declarations of <code>_cuda_tid</code>, <code>_cuda_kid</code>,
 * <code>comm</code> and <code>_cuda_thread_barrier</code></li>
 * <li><code>$read_set_push()</code>, <code>$write_set_push()</code></li>
 * <li>the (transformed) kernel statements</li>
 * <li><code>$check_data_race(_cuda_this, _cuda_kid)</code></li>
 * <li><code>$read_set_pop()</code>, <code>$write_set_pop()</code></li>
 * <li><code>$barrier_destroy(_cuda_thread_barrier)</code>,
 * <code>$comm_destroy(comm)</code>, <code>$local_end()</code></li>
 * </ol>
 * <p>
 * The function takes a single formal parameter <code>uint3 threadIdx</code>
 * and returns <code>void</code>. As a last step, every call to
 * <code>__syncthreads</code> inside the new definition is replaced by
 * <code>$cuda_barrier(_cuda_this, _cuda_kid, _cuda_thread_barrier)</code>.
 * </p>
 *
 * @param body
 *            the body of the original kernel; it is modified in place by
 *            the shuffle/atomic translation steps
 * @return the completed thread function definition
 */
protected FunctionDefinitionNode buildThreadDefinition(
        CompoundStatementNode body) {
    translateShflCalls(body);

    CompoundStatementNode kernelBody = translateAtomicCalls(body);
    Source source = kernelBody.getSource();

    // Formal parameter list of the thread function: (uint3 threadIdx)
    VariableDeclarationNode threadIdxFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "threadIdx"),
                    nodeFactory.newTypedefNameNode(
                            nodeFactory.newIdentifierNode(source, "uint3"),
                            null));
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "threadFormals",
                    Arrays.asList(threadIdxFormal));

    List<BlockItemNode> items = new ArrayList<BlockItemNode>();

    // ---------------------------------------------------------- prologue
    // $local_start();
    FunctionCallNode beginLocal = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$local_start"),
            Arrays.asList(), null);
    items.add(nodeFactory.newExpressionStatementNode(beginLocal));

    // int _cuda_tid = $cuda_index(blockDim, threadIdx);
    VariableDeclarationNode threadIndexDecl = nodeFactory
            .newVariableDeclarationNode(source,
                    this.identifier("_cuda_tid"),
                    nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT),
                    nodeFactory.newFunctionCallNode(source,
                            this.identifierExpression(source, "$cuda_index"),
                            Arrays.asList(
                                    this.identifierExpression(source, "blockDim"),
                                    this.identifierExpression(source, "threadIdx")),
                            null));
    items.add(threadIndexDecl);

    // int _cuda_kid = $cuda_kernel_index(gridDim, blockDim, blockIdx, threadIdx);
    VariableDeclarationNode kernelIndexDecl = nodeFactory
            .newVariableDeclarationNode(source,
                    this.identifier("_cuda_kid"),
                    nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT),
                    nodeFactory.newFunctionCallNode(source,
                            this.identifierExpression(source, "$cuda_kernel_index"),
                            Arrays.asList(
                                    this.identifierExpression(source, "gridDim"),
                                    this.identifierExpression(source, "blockDim"),
                                    this.identifierExpression(source, "blockIdx"),
                                    this.identifierExpression(source, "threadIdx")),
                            null));
    items.add(kernelIndexDecl);

    // $comm comm = $comm_create($here, gComm, _cuda_tid);
    VariableDeclarationNode commDecl = nodeFactory.newVariableDeclarationNode(
            source, this.identifier("comm"),
            nodeFactory.newTypedefNameNode(this.identifier("$comm"), null));
    commDecl.setInitializer(nodeFactory.newFunctionCallNode(source,
            this.identifierExpression("$comm_create"),
            Arrays.asList(nodeFactory.newHereNode(source),
                    this.identifierExpression("gComm"),
                    this.identifierExpression("_cuda_tid")),
            null));
    items.add(commDecl);

    // $barrier _cuda_thread_barrier =
    //     $barrier_create($here, _cuda_block_barrier, _cuda_tid);
    FunctionCallNode barrierCreateCall = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$barrier_create"),
            Arrays.asList(nodeFactory.newHereNode(source),
                    this.identifierExpression(source, "_cuda_block_barrier"),
                    this.identifierExpression("_cuda_tid")),
            null);
    VariableDeclarationNode barrierDecl = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source,
                            "_cuda_thread_barrier"),
                    nodeFactory.newTypedefNameNode(
                            nodeFactory.newIdentifierNode(source, "$barrier"),
                            null),
                    barrierCreateCall);
    items.add(barrierDecl);

    // FIXME (kept from the original code): not verified that plain
    // function calls are adequate for the read/write set push.
    FunctionCallNode pushReadSet = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$read_set_push"),
            Arrays.asList(), null);
    FunctionCallNode pushWriteSet = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$write_set_push"),
            Arrays.asList(), null);
    items.add(nodeFactory.newExpressionStatementNode(pushReadSet));
    items.add(nodeFactory.newExpressionStatementNode(pushWriteSet));

    // ------------------------------------------------------ kernel body
    for (BlockItemNode item : kernelBody) {
        if (item != null)
            items.add(item.copy());
    }

    // ---------------------------------------------------------- epilogue
    // $check_data_race(_cuda_this, _cuda_kid);
    FunctionCallNode raceCheck = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$check_data_race"),
            Arrays.asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_kid")),
            null);
    items.add(nodeFactory.newExpressionStatementNode(raceCheck));

    // FIXME (kept from the original code): not verified that plain
    // function calls are adequate for the read/write set pop.
    FunctionCallNode popReadSet = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$read_set_pop"),
            Arrays.asList(), null);
    FunctionCallNode popWriteSet = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$write_set_pop"),
            Arrays.asList(), null);
    items.add(nodeFactory.newExpressionStatementNode(popReadSet));
    items.add(nodeFactory.newExpressionStatementNode(popWriteSet));

    // $barrier_destroy(_cuda_thread_barrier);
    FunctionCallNode barrierDestroyCall = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$barrier_destroy"),
            Arrays.asList(this.identifierExpression(source,
                    "_cuda_thread_barrier")),
            null);
    items.add(nodeFactory.newExpressionStatementNode(barrierDestroyCall));

    // $comm_destroy(comm);
    FunctionCallNode commDestroyCall = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression("$comm_destroy"),
            Arrays.asList(this.identifierExpression("comm")), null);
    items.add(nodeFactory.newExpressionStatementNode(commDestroyCall));

    // $local_end();
    FunctionCallNode endLocal = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$local_end"),
            Arrays.asList(), null);
    items.add(nodeFactory.newExpressionStatementNode(endLocal));

    // void _cuda_thread(uint3 threadIdx) { ... }
    FunctionDefinitionNode threadFunction = nodeFactory
            .newFunctionDefinitionNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_thread"),
                    nodeFactory.newFunctionTypeNode(source,
                            nodeFactory.newVoidTypeNode(source), formals,
                            false),
                    null,
                    nodeFactory.newCompoundStatementNode(source, items));

    // Every __syncthreads() call becomes
    // $cuda_barrier(_cuda_this, _cuda_kid, _cuda_thread_barrier).
    FunctionCallNode barrierCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_barrier"),
            Arrays.asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_kid"),
                    this.identifierExpression(source, "_cuda_thread_barrier")),
            null);
    replaceSyncThreadsCalls(threadFunction, barrierCall);
    return threadFunction;
}
/**
 * Replaces every call to <code>__syncthreads</code> found below the given
 * node with a fresh copy of the supplied replacement expression.
 *
 * <p>
 * The descendants of <code>root</code> are visited recursively. A child is
 * replaced when it is a function call whose function expression is a plain
 * identifier named <code>__syncthreads</code>; a replaced call is not
 * descended into. <code>root</code> itself is never replaced.
 * </p>
 *
 * @param root
 *            the root node of the (sub)tree to search
 * @param replacement
 *            the expression that takes the place of each
 *            <code>__syncthreads</code> call; it is copied for every
 *            occurrence
 */
protected void replaceSyncThreadsCalls(ASTNode root,
        ExpressionNode replacement) {
    for (ASTNode node : root.children()) {
        if (node == null)
            continue;
        if (node instanceof FunctionCallNode) {
            ExpressionNode callee = ((FunctionCallNode) node).getFunction();
            if (callee instanceof IdentifierExpressionNode) {
                String calleeName = ((IdentifierExpressionNode) callee)
                        .getIdentifier().name();
                if (calleeName.equals("__syncthreads")) {
                    // Swap the call in place; nothing below it to visit.
                    root.setChild(node.childIndex(), replacement.copy());
                    continue;
                }
            }
        }
        replaceSyncThreadsCalls(node, replacement);
    }
}
/**
 * Rebuilds a block so that atomic function calls occurring in it are
 * transformed for data race detection. This method works at the statement
 * level: it walks the direct children of <code>root</code> and assembles a
 * new compound statement from them.
 *
 * <ul>
 * <li>A child that is itself a compound statement is translated
 * recursively and the resulting block is appended.</li>
 * <li>Any other child is first handed to
 * {@link #replaceAtomicExpressions(ASTNode, List)}, which may append
 * statements to the new block ahead of the child, and then a copy of the
 * child is appended.</li>
 * </ul>
 *
 * @param root
 *            the node whose children form the block to translate (the
 *            kernel body on the initial call)
 * @return a new compound statement, carrying the source of
 *         <code>root</code>, that holds the transformed statements
 */
protected CompoundStatementNode translateAtomicCalls(ASTNode root) {
    List<BlockItemNode> rebuiltItems = new ArrayList<BlockItemNode>();

    for (ASTNode node : root.children()) {
        if (node == null)
            continue;
        boolean isNestedBlock = node.nodeKind() == NodeKind.STATEMENT
                && ((StatementNode) node)
                        .statementKind() == StatementKind.COMPOUND;
        if (isNestedBlock) {
            rebuiltItems.add(translateAtomicCalls(node));
        } else {
            replaceAtomicExpressions(node, rebuiltItems);
            rebuiltItems.add((BlockItemNode) node.copy());
        }
    }
    return nodeFactory.newCompoundStatementNode(root.getSource(),
            rebuiltItems);
}
/**
 * Determines whether the given expression is a call to one of the atomic
 * functions defined in <code>cuda.cvl</code> and, if so, returns the
 * definition of that function.
 *
 * <p>
 * An expression qualifies when all of the following hold:
 * </p>
 * <ul>
 * <li>it is a function call whose function expression is a plain
 * identifier;</li>
 * <li>the identifier's name contains "atomic" (case-insensitively);</li>
 * <li>the identifier resolves to a {@link Function} entity that has a
 * definition;</li>
 * <li>the first token of that definition comes from a source file named
 * <code>cuda.cvl</code>.</li>
 * </ul>
 * <p>
 * In every other case, including an identifier that resolves to something
 * other than a function (e.g. a function-pointer variable) or to a
 * function that is only declared, <code>null</code> is returned instead of
 * failing with a cast or null-pointer exception.
 * </p>
 *
 * @param expression
 *            the expression to inspect
 * @return the definition of the atomic function being called, or
 *         <code>null</code> if the expression is not such a call
 */
protected FunctionDefinitionNode findAtomicDefinition(ExpressionNode expression) {
    if (expression.expressionKind() != ExpressionKind.FUNCTION_CALL)
        return null;

    ExpressionNode function = ((FunctionCallNode) expression).getFunction();
    if (!(function instanceof IdentifierExpressionNode))
        return null;

    IdentifierNode identifier = ((IdentifierExpressionNode) function)
            .getIdentifier();
    if (!identifier.name().toLowerCase().contains("atomic"))
        return null;

    // The name alone is not enough: the identifier may denote a variable
    // (call through a function pointer) or may be unresolved.
    if (!(identifier.getEntity() instanceof Function))
        return null;
    Function functionEntity = (Function) identifier.getEntity();

    // A function that is only declared has no definition to inspect.
    FunctionDefinitionNode functionDefinition = functionEntity
            .getDefinition();
    if (functionDefinition == null)
        return null;

    Source source = functionDefinition.getSource();
    CivlcToken token = source.getFirstToken();
    SourceFile sourceFile = token.getSourceFile();
    if (sourceFile.getName().equals("cuda.cvl"))
        return functionDefinition;
    return null;
}
/**
 * Hoists the actual arguments of an atomic function call into temporary
 * variables that are declared ahead of the call in the new kernel.
 *
 * <p>
 * Each argument of <code>functionCall</code> is handled as follows:
 * </p>
 * <ul>
 * <li>constant expressions are left untouched;</li>
 * <li>an argument that is itself an atomic call (see
 * {@link #findAtomicDefinition(ExpressionNode)}) has its own arguments
 * transformed recursively and is then replaced by the temporary returned
 * from
 * {@link #atomicCallTransform(FunctionCallNode, FunctionDefinitionNode, List)};</li>
 * <li>any other argument is first searched for nested atomic calls via
 * {@link #replaceAtomicExpressions(ASTNode, List)}, then copied into the
 * initializer of a new temporary variable whose declaration is appended to
 * <code>newKernelItems</code>; the argument is replaced by a reference to
 * that temporary.</li>
 * </ul>
 * <p>
 * Temporary names are "$" followed by the result of
 * <code>newTemporaryVariableName()</code>. Replacements are always made in
 * the call's <em>argument list</em>, i.e. at the position the argument
 * occupies among the arguments.
 * </p>
 *
 * @param functionCall
 *            the call whose arguments are transformed in place
 * @param newKernelItems
 *            the list of block items being built into the new kernel;
 *            generated declarations/statements are appended to it
 */
protected void transformParameters(FunctionCallNode functionCall, List<BlockItemNode> newKernelItems) {
    for (ASTNode child : functionCall.getArguments()) {
        if (child == null)
            continue;
        if (child.nodeKind() != NodeKind.EXPRESSION)
            continue;

        ExpressionNode expression = (ExpressionNode) child;
        if (expression.expressionKind() == ExpressionKind.CONSTANT)
            continue;

        FunctionDefinitionNode functionDefinition = findAtomicDefinition(expression);
        if (functionDefinition != null) {
            transformParameters((FunctionCallNode) expression, newKernelItems);
            // child.childIndex() is the position within the argument
            // sequence, so the replacement must go through setArgument;
            // setChild on the call node would overwrite one of the call's
            // own children (function expression, context arguments, ...).
            functionCall.setArgument(child.childIndex(),
                    atomicCallTransform((FunctionCallNode) expression,
                            functionDefinition, newKernelItems));
            continue;
        }

        replaceAtomicExpressions(child, newKernelItems);

        Source source = expression.getSource();
        Type type = expression.getType();
        TypeNode typeNode;
        if (type instanceof CommonFunctionType) {
            // For an expression of function type the temporary is given
            // the function's return type.
            FunctionType functionType = (FunctionType) type;
            ObjectType returnType = functionType.getReturnType();
            typeNode = this.typeNode(source, returnType);
        } else {
            typeNode = this.typeNode(source, type);
        }

        String tmpVariableName = "$" + this.newTemporaryVariableName();
        VariableDeclarationNode tmpDeclaration = nodeFactory.newVariableDeclarationNode(source,
                nodeFactory.newIdentifierNode(source, tmpVariableName),
                typeNode.copy());
        tmpDeclaration.setInitializer(expression.copy());
        newKernelItems.add(tmpDeclaration);
        functionCall.setArgument(child.childIndex(),
                nodeFactory.newIdentifierExpressionNode(source,
                        this.identifier(tmpVariableName)));
    }
}
/**
 * Searches the subtree below <code>root</code> for atomic function calls
 * and transforms them. Unlike {@link #translateAtomicCalls(ASTNode)} this
 * method works at the expression level: it rewrites children of
 * <code>root</code> in place and does not itself copy <code>root</code>
 * into the new kernel.
 *
 * <p>
 * For each non-null child of <code>root</code>:
 * </p>
 * <ul>
 * <li>a compound statement is replaced by its translation from
 * {@link #translateAtomicCalls(ASTNode)};</li>
 * <li>any other statement is searched with a fresh item list and is then
 * replaced by a new compound statement made of the items collected for it
 * followed by a copy of the statement;</li>
 * <li>an expression that is an atomic call (see
 * {@link #findAtomicDefinition(ExpressionNode)}) has its arguments
 * transformed and is replaced by the temporary produced by
 * {@link #atomicCallTransform(FunctionCallNode, FunctionDefinitionNode, List)};</li>
 * <li>every other child is searched recursively.</li>
 * </ul>
 *
 * @param root
 *            the root of the subtree to search
 * @param newKernelItems
 *            the list of block items being built into the new kernel;
 *            statements generated for atomic calls are appended to it
 */
protected void replaceAtomicExpressions(ASTNode root, List<BlockItemNode> newKernelItems) {
    for (ASTNode node : root.children()) {
        if (node == null)
            continue;

        if (node.nodeKind() == NodeKind.STATEMENT) {
            StatementNode stmt = (StatementNode) node;
            if (stmt.statementKind() == StatementKind.COMPOUND) {
                root.setChild(node.childIndex(), translateAtomicCalls(stmt));
            } else {
                // Statements generated for this nested statement are kept
                // next to it, inside a block of their own.
                List<BlockItemNode> wrappedItems = new ArrayList<BlockItemNode>();
                replaceAtomicExpressions(node, wrappedItems);
                wrappedItems.add((BlockItemNode) node.copy());
                CompoundStatementNode wrapper = nodeFactory
                        .newCompoundStatementNode(root.getSource(),
                                wrappedItems);
                root.setChild(node.childIndex(), wrapper);
            }
            continue;
        }

        if (node.nodeKind() == NodeKind.EXPRESSION) {
            ExpressionNode expr = (ExpressionNode) node;
            FunctionDefinitionNode atomicDefinition = findAtomicDefinition(expr);
            if (atomicDefinition != null) {
                FunctionCallNode atomicCall = (FunctionCallNode) expr;
                transformParameters(atomicCall, newKernelItems);
                root.setChild(node.childIndex(), atomicCallTransform(
                        atomicCall, atomicDefinition, newKernelItems));
                continue;
            }
        }

        replaceAtomicExpressions(node, newKernelItems);
    }
}
/**
 * Emits the statement sequence that replaces a single atomic function call
 * and returns a reference to the temporary holding the call's result.
 *
 * <p>
 * The following items are appended to <code>newKernelItems</code>, in this
 * order:
 * </p>
 * <ol>
 * <li><code>$publish(_cuda_this, _cuda_kid);</code></li>
 * <li><code>$yield();</code></li>
 * <li>a declaration of a new temporary, typed with the return type of
 * <code>functionDefinition</code> and initialized with a copy of
 * <code>atomicCall</code>;</li>
 * <li><code>$check_data_race(_cuda_this, _cuda_kid);</code></li>
 * <li><code>$clear_mem_sets(_cuda_this, _cuda_kid);</code></li>
 * </ol>
 * <p>
 * The temporary is named "$" followed by the result of
 * <code>newTemporaryVariableName()</code>. The original call node is not
 * modified; the caller is expected to substitute the returned identifier
 * expression for it.
 * </p>
 *
 * @param atomicCall
 *            the atomic function call being transformed
 * @param functionDefinition
 *            the definition of the called atomic function, used for its
 *            return type
 * @param newKernelItems
 *            the list of block items being built into the new kernel
 * @return an identifier expression naming the temporary variable that
 *         receives the result of the atomic call
 */
protected IdentifierExpressionNode atomicCallTransform(FunctionCallNode atomicCall, FunctionDefinitionNode functionDefinition, List<BlockItemNode> newKernelItems) {
    Source source = atomicCall.getSource();

    // $publish(_cuda_this, _cuda_kid);
    FunctionCallNode publishCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$publish"),
            Arrays.asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_kid")),
            null);
    newKernelItems.add(nodeFactory.newExpressionStatementNode(publishCall));

    // $yield();
    FunctionCallNode yieldCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$yield"),
            Arrays.asList(), null);
    newKernelItems.add(nodeFactory.newExpressionStatementNode(yieldCall));

    // <return type> $tmp = <atomic call>;
    TypeNode resultType = ((FunctionTypeNode) functionDefinition
            .getTypeNode()).getReturnType();
    String resultName = "$" + this.newTemporaryVariableName();
    VariableDeclarationNode resultDecl = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, resultName),
                    resultType.copy());
    resultDecl.setInitializer(atomicCall.copy());
    newKernelItems.add(resultDecl);

    // $check_data_race(_cuda_this, _cuda_kid);
    FunctionCallNode raceCheckCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$check_data_race"),
            Arrays.asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_kid")),
            null);
    newKernelItems.add(nodeFactory.newExpressionStatementNode(raceCheckCall));

    // $clear_mem_sets(_cuda_this, _cuda_kid);
    FunctionCallNode clearCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$clear_mem_sets"),
            Arrays.asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_kid")),
            null);
    newKernelItems.add(nodeFactory.newExpressionStatementNode(clearCall));

    return nodeFactory.newIdentifierExpressionNode(source,
            this.identifier(resultName));
}
/**
 * Recursively searches the subtree below <code>root</code> for calls to
 * the shuffle functions declared in <code>cuda.h</code> and rewrites each
 * of them with {@link #shflCallTransform(FunctionCallNode)}.
 *
 * <p>
 * A call is rewritten when its function expression is a plain identifier
 * whose name contains both "__shfl" and "_sync" (case-insensitively), the
 * identifier resolves to a {@link Function} entity, and the first token of
 * that function's first declaration comes from a source file named
 * <code>cuda.h</code>. Identifiers that resolve to anything other than a
 * function are ignored rather than cast. Every child, rewritten or not, is
 * descended into afterwards.
 * </p>
 *
 * @param root
 *            the root of the subtree to search; matching calls are
 *            modified in place
 */
protected void translateShflCalls(ASTNode root) {
    for (ASTNode child : root.children()) {
        if (child == null)
            continue;
        if (child.nodeKind() == NodeKind.EXPRESSION
                && ((ExpressionNode) child)
                        .expressionKind() == ExpressionKind.FUNCTION_CALL) {
            FunctionCallNode functionCall = (FunctionCallNode) child;
            ExpressionNode function = functionCall.getFunction();
            if (function instanceof IdentifierExpressionNode) {
                IdentifierNode identifier = ((IdentifierExpressionNode) function)
                        .getIdentifier();
                String lowerCaseName = identifier.name().toLowerCase();
                // Guard the cast: a matching name may still denote a
                // variable (function pointer) or be unresolved.
                if (lowerCaseName.contains("__shfl")
                        && lowerCaseName.contains("_sync")
                        && identifier.getEntity() instanceof Function) {
                    Function functionEntity = (Function) identifier
                            .getEntity();
                    FunctionDeclarationNode functionDeclaration = (FunctionDeclarationNode) functionEntity
                            .getDeclaration(0);
                    Source source = functionDeclaration.getSource();
                    CivlcToken token = source.getFirstToken();
                    SourceFile sourceFile = token.getSourceFile();
                    if (sourceFile.getName().equals("cuda.h")) {
                        shflCallTransform(functionCall);
                    }
                }
            }
        }
        translateShflCalls(child);
    }
}
/**
 * Rewrites a single shuffle call in place so that it targets the
 * corresponding <code>_cuda</code>-prefixed function with an extended
 * argument list.
 *
 * <p>
 * If the call has fewer than four arguments, the integer constant 32 is
 * appended first (presumably the default value of the optional width
 * argument; confirm against the declarations in <code>cuda.h</code>).
 * Then the identifiers <code>numThreads</code>, <code>_cuda_tid</code>,
 * <code>comm</code> and <code>warpBarriers</code> are appended, in that
 * order. Finally the called identifier is renamed by prefixing it with
 * "_cuda".
 * </p>
 *
 * @param shflCall
 *            the shuffle call to rewrite; its function expression must be
 *            an {@link IdentifierExpressionNode}
 */
protected void shflCallTransform(FunctionCallNode shflCall) {
    SequenceNode<ExpressionNode> callArguments = shflCall.getArguments();

    if (callArguments.numChildren() < 4) {
        callArguments.addSequenceChild(
                nodeFactory.newIntConstantNode(shflCall.getSource(), 32));
    }

    // Extra arguments expected by the _cuda-prefixed replacement.
    String[] extraArgumentNames = { "numThreads", "_cuda_tid", "comm",
            "warpBarriers" };
    for (String extraName : extraArgumentNames) {
        callArguments.addSequenceChild(this.identifierExpression(extraName));
    }

    IdentifierNode callee = ((IdentifierExpressionNode) shflCall
            .getFunction()).getIdentifier();
    callee.setName("_cuda" + callee.name());
}
/**
 * Transforms the kernel call to instead use the kernel's transformed
 * signature as transformed by
 * {@link Cuda2CIVLWorker#kernelDeclarationTransform(FunctionDeclarationNode)}.
 *
 * <p>
 * The new call's argument list is assembled as follows:
 * </p>
 * <ol>
 * <li>the first two context arguments; one whose converted type (after
 * stripping a qualifier) is the basic type <code>int</code> is first
 * stored in a new <code>dim3</code> temporary initialized with
 * <code>$cuda_to_dim3(arg)</code> and passed by name, any other is
 * copied;</li>
 * <li>the third and fourth context arguments, each replaced by the integer
 * constant <code>0</code> when absent from the original call;</li>
 * <li>copies of the ordinary arguments of the original call.</li>
 * </ol>
 * <p>
 * If the called function is a plain identifier, its name is mapped through
 * <code>transformKernelName</code>; otherwise the function expression is
 * copied unchanged. The result is a compound statement containing the
 * temporary declarations followed by the new call.
 * </p>
 *
 * @param kernelCall
 *            a FunctionCallNode which is a kernel call
 * @return The transformed kernel call
 * @throws IllegalStateException
 *             if the constant <code>0</code> for a missing context
 *             argument cannot be created
 */
protected StatementNode kernelCallTransform(FunctionCallNode kernelCall) {
    Source source = kernelCall.getSource();
    List<VariableDeclarationNode> tempVarDecls = new ArrayList<>();
    List<ExpressionNode> newArgumentList = new ArrayList<>();

    // Context arguments 0 and 1: int values are converted to dim3.
    for (int i = 0; i < 2; i++) {
        ExpressionNode arg = kernelCall.getContextArgument(i);
        Type argType = arg.getConvertedType();
        if (argType.kind() == TypeKind.QUALIFIED) {
            argType = ((QualifiedObjectType) argType).getBaseType();
        }
        if (argType.kind() == TypeKind.BASIC
                && ((StandardBasicType) argType)
                        .getBasicTypeKind() == BasicTypeKind.INT) {
            String tmpVar = newTemporaryVariableName();
            ExpressionNode intConvertedToDim3 = nodeFactory
                    .newFunctionCallNode(source,
                            this.identifierExpression("$cuda_to_dim3"),
                            Arrays.asList(arg.copy()), null);
            tempVarDecls
                    .add(nodeFactory.newVariableDeclarationNode(source,
                            this.identifier(tmpVar),
                            nodeFactory.newTypedefNameNode(
                                    this.identifier("dim3"), null),
                            intConvertedToDim3));
            newArgumentList.add(this.identifierExpression(tmpVar));
        } else {
            newArgumentList.add(arg.copy());
        }
    }

    // Context arguments 2 and 3 are optional and default to 0.
    int numContextArguments = kernelCall.getNumberOfContextArguments();
    for (int i = 2; i < 4; i++) {
        if (i < numContextArguments) {
            newArgumentList.add(kernelCall.getContextArgument(i).copy());
            continue;
        }
        try {
            newArgumentList
                    .add(nodeFactory.newIntegerConstantNode(source, "0"));
        } catch (SyntaxException e) {
            // Carrying on would silently produce a call with a missing
            // argument, so fail loudly and keep the cause.
            throw new IllegalStateException(
                    "Unable to create the default value 0 for context argument "
                            + i + " of a kernel call",
                    e);
        }
    }

    for (int i = 0; i < kernelCall.getNumberOfArguments(); i++) {
        newArgumentList.add(kernelCall.getArgument(i).copy());
    }

    ExpressionNode newFunction;
    if (kernelCall.getFunction() instanceof IdentifierExpressionNode) {
        IdentifierExpressionNode identifierExpression = (IdentifierExpressionNode) kernelCall
                .getFunction();
        newFunction = this.identifierExpression(transformKernelName(
                identifierExpression.getIdentifier().name()));
    } else {
        newFunction = kernelCall.getFunction().copy();
    }

    FunctionCallNode newFunctionCall = nodeFactory.newFunctionCallNode(
            source, newFunction, newArgumentList, null);
    List<BlockItemNode> blockItems = new ArrayList<>();
    blockItems.addAll(tempVarDecls);
    blockItems.add(nodeFactory.newExpressionStatementNode(newFunctionCall));
    return nodeFactory.newCompoundStatementNode(source, blockItems);
}
/**
 * Removes the declarations of the CUDA built-in variables
 * <code>threadIdx</code>, <code>blockIdx</code>, <code>gridDim</code> and
 * <code>blockDim</code> that originate from <code>cuda.h</code>.
 *
 * <p>
 * The tree below <code>root</code> is searched recursively. A variable
 * declaration is removed from its parent when it has an identifier, the
 * first token of that identifier comes from a source file named
 * <code>cuda.h</code>, and the identifier's name is one of the four
 * built-in names. Removed declarations are not descended into.
 * </p>
 *
 * @param root
 *            the root node of the (sub)tree to clean up
 */
protected void removeBuiltinDefinitions(ASTNode root) {
    Set<String> builtinNames = new HashSet<>(
            Arrays.asList("threadIdx", "blockIdx", "gridDim", "blockDim"));

    for (ASTNode node : root.children()) {
        if (node == null)
            continue;
        if (node.nodeKind() == NodeKind.VARIABLE_DECLARATION) {
            IdentifierNode declared = ((VariableDeclarationNode) node)
                    .getIdentifier();
            boolean declaredInCudaHeader = declared != null
                    && declared.getSource().getFirstToken().getSourceFile()
                            .getName().equals("cuda.h");
            if (declaredInCudaHeader
                    && builtinNames.contains(declared.name())) {
                root.removeChild(node.childIndex());
                continue;
            }
        }
        removeBuiltinDefinitions(node);
    }
}
}