// Cuda2CIVLWorker.java
package edu.udel.cis.vsl.civl.transform.common;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import edu.udel.cis.vsl.abc.ast.IF.AST;
import edu.udel.cis.vsl.abc.ast.IF.ASTFactory;
import edu.udel.cis.vsl.abc.ast.node.IF.ASTNode;
import edu.udel.cis.vsl.abc.ast.node.IF.ASTNode.NodeKind;
import edu.udel.cis.vsl.abc.ast.node.IF.SequenceNode;
import edu.udel.cis.vsl.abc.ast.node.IF.declaration.FunctionDefinitionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.CastNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.DotNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.ExpressionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.ExpressionNode.ExpressionKind;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.FunctionCallNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.OperatorNode;
import edu.udel.cis.vsl.abc.ast.node.IF.expression.OperatorNode.Operator;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.BlockItemNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.CompoundStatementNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.ExpressionStatementNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.StatementNode;
import edu.udel.cis.vsl.abc.ast.node.IF.statement.StatementNode.StatementKind;
import edu.udel.cis.vsl.abc.ast.node.IF.type.ArrayTypeNode;
import edu.udel.cis.vsl.abc.ast.node.IF.type.FunctionTypeNode;
import edu.udel.cis.vsl.abc.ast.node.IF.type.TypeNode.TypeNodeKind;
import edu.udel.cis.vsl.abc.ast.type.IF.PointerType;
import edu.udel.cis.vsl.abc.ast.type.IF.QualifiedObjectType;
import edu.udel.cis.vsl.abc.ast.type.IF.StandardBasicType;
import edu.udel.cis.vsl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import edu.udel.cis.vsl.abc.ast.type.IF.Type;
import edu.udel.cis.vsl.abc.ast.type.IF.Type.TypeKind;
import edu.udel.cis.vsl.abc.token.IF.Source;
import edu.udel.cis.vsl.abc.token.IF.SyntaxException;
import edu.udel.cis.vsl.civl.transform.IF.Cuda2CIVLTransformer;
/**
 * Worker for {@link Cuda2CIVLTransformer}. Its {@link #transform(AST)}
 * method rewrites the AST of a program that includes the CUDA header: the
 * declarations of the CUDA builtin variables are removed, {@code cudaMalloc}
 * calls and kernel launches are replaced by generated expressions and
 * statements, and every kernel definition is replaced by a generated
 * function built from nested kernel / block / thread function definitions.
 */
public class Cuda2CIVLWorker extends BaseWorker {
/**
 * Name of the header whose presence is checked by {@link #transform(AST)};
 * an AST without this header is returned untouched.
 */
private static String CUDA_HEADER = "cuda.h";
/**
 * Counter appended to the names produced by
 * {@link #newTemporaryVariableName()}; incremented on every call.
 */
private int tempVarNum;
/**
 * Creates a worker registered under {@link Cuda2CIVLTransformer#LONG_NAME}
 * and sets the identifier prefix used for generated names to
 * {@code "_cuda_"}.
 *
 * @param astFactory
 *            the factory handed to the base worker; also used by
 *            {@link #transform(AST)} to build the resulting AST
 */
public Cuda2CIVLWorker(ASTFactory astFactory) {
    super(Cuda2CIVLTransformer.LONG_NAME, astFactory);
    identifierPrefix = "_cuda_";
}
/**
 * Applies the CUDA rewriting passes to the given AST.
 * <p>
 * An AST that does not include the CUDA header is returned as is. Otherwise
 * the AST is released, its root node is rewritten in place by the passes
 * below (in this order), and a new AST is built from the rewritten root
 * using the source files and whole-program flag of the original AST.
 *
 * @param ast
 *            the AST to transform
 * @return {@code ast} itself when the CUDA header is absent, otherwise a
 *         newly built AST
 * @throws SyntaxException
 *             declared by the overridden method; may be raised by the
 *             passes or by the construction of the new AST
 */
@Override
public AST transform(AST ast) throws SyntaxException {
    boolean includesCudaHeader = this.hasHeader(ast, CUDA_HEADER);

    if (!includesCudaHeader) {
        return ast;
    }

    SequenceNode<BlockItemNode> rootNode = ast.getRootNode();

    // The root can only be modified once the original AST is released.
    ast.release();

    removeBuiltinDefinitions(rootNode);
    translateCudaMallocCalls(rootNode);
    translateKernelCalls(rootNode);

    boolean hasGeneratedMain = this.has_gen_mainFunction(rootNode);
    if (!hasGeneratedMain) {
        transformMainFunction(rootNode);
        createNewMainFunction(rootNode);
    }

    translateMainDefinition(rootNode);
    translateKernelDefinitions(rootNode);

    return astFactory.newAST(rootNode, ast.getSourceFiles(),
            ast.isWholeProgram());
}
/**
 * Produces a fresh temporary variable name of the form
 * {@code <identifierPrefix>tmp<N>}, where {@code N} is the current value of
 * the internal counter. The counter is incremented by each call.
 *
 * @return the generated name
 */
protected String newTemporaryVariableName() {
    int suffix = tempVarNum;

    tempVarNum = suffix + 1;
    return this.identifierPrefix + "tmp" + suffix;
}
/**
 * Looks through the direct children of {@code root} for function
 * definitions named {@code main} and hands each one to
 * {@link #transformMainFunctionDefinition(FunctionDefinitionNode)}.
 * Null children and children of any other kind are skipped; the search
 * does not descend below the direct children.
 *
 * @param root
 *            the node whose children are inspected
 */
protected void translateMainDefinition(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null
                || node.nodeKind() != NodeKind.FUNCTION_DEFINITION) {
            continue;
        }

        FunctionDefinitionNode function = (FunctionDefinitionNode) node;

        if ("main".equals(function.getName())) {
            transformMainFunctionDefinition(function);
        }
    }
}
/**
 * Replaces every direct child of {@code root} that is a function
 * definition carrying the global function specifier by the result of
 * {@link #kernelDefinitionTransform(FunctionDefinitionNode)}. The
 * replacement is stored at the child index of the original definition.
 *
 * @param root
 *            the node whose children are inspected and possibly replaced
 */
protected void translateKernelDefinitions(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null
                || node.nodeKind() != NodeKind.FUNCTION_DEFINITION) {
            continue;
        }

        FunctionDefinitionNode kernel = (FunctionDefinitionNode) node;

        if (!kernel.hasGlobalFunctionSpecifier()) {
            continue;
        }

        int position = node.childIndex();
        FunctionDefinitionNode rewritten = kernelDefinitionTransform(kernel);

        root.setChild(position, rewritten);
    }
}
/**
 * Walks the subtree below {@code root} and replaces every call whose
 * function is the identifier {@code cudaMalloc} by the expression produced
 * by {@link #cudaMallocTransform(FunctionCallNode)}. A replaced call is not
 * descended into; every other non-null child is processed recursively.
 *
 * @param root
 *            the node whose descendants are searched
 */
protected void translateCudaMallocCalls(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null) {
            continue;
        }

        FunctionCallNode mallocCall = asCudaMallocCall(node);

        if (mallocCall == null) {
            translateCudaMallocCalls(node);
        } else {
            root.setChild(mallocCall.childIndex(),
                    cudaMallocTransform(mallocCall));
        }
    }
}

/**
 * Returns {@code node} as a function call if it is an expression of kind
 * function call whose function is an identifier expression named
 * {@code cudaMalloc}; returns {@code null} otherwise.
 */
private FunctionCallNode asCudaMallocCall(ASTNode node) {
    if (node.nodeKind() != NodeKind.EXPRESSION) {
        return null;
    }

    ExpressionNode expression = (ExpressionNode) node;

    if (expression.expressionKind() != ExpressionKind.FUNCTION_CALL) {
        return null;
    }

    FunctionCallNode call = (FunctionCallNode) expression;
    ExpressionNode callee = call.getFunction();

    if (callee.expressionKind() != ExpressionKind.IDENTIFIER_EXPRESSION) {
        return null;
    }

    String calleeName = ((IdentifierExpressionNode) callee).getIdentifier()
            .name();

    return calleeName.equals("cudaMalloc") ? call : null;
}
/**
 * Walks the subtree below {@code root} and replaces every expression
 * statement whose expression is a function call with at least one context
 * argument by the statement produced by
 * {@link #kernelCallTransform(FunctionCallNode)}. A replaced statement is
 * not descended into; every other non-null child is processed recursively.
 *
 * @param root
 *            the node whose descendants are searched
 */
protected void translateKernelCalls(ASTNode root) {
    for (ASTNode node : root.children()) {
        if (node == null) {
            continue;
        }

        FunctionCallNode launch = asKernelLaunchStatement(node);

        if (launch == null) {
            translateKernelCalls(node);
        } else {
            root.setChild(node.childIndex(), kernelCallTransform(launch));
        }
    }
}

/**
 * Returns the function call wrapped by {@code node} if {@code node} is an
 * expression statement whose expression is a function call having one or
 * more context arguments; returns {@code null} otherwise.
 */
private FunctionCallNode asKernelLaunchStatement(ASTNode node) {
    if (node.nodeKind() != NodeKind.STATEMENT) {
        return null;
    }

    StatementNode statement = (StatementNode) node;

    if (statement.statementKind() != StatementKind.EXPRESSION) {
        return null;
    }

    ExpressionNode expression = ((ExpressionStatementNode) statement)
            .getExpression();

    if (expression.expressionKind() != ExpressionKind.FUNCTION_CALL) {
        return null;
    }

    FunctionCallNode call = (FunctionCallNode) expression;

    return call.getNumberOfContextArguments() > 0 ? call : null;
}
/**
 * Builds the replacement for a call
 * {@code cudaMalloc((void**) ptrPtr, size)}: the comma expression
 * {@code *ptrPtr = (type) $malloc(scope, size), cudaSuccess}.
 * <ul>
 * <li>{@code ptrPtr} is the first argument with all enclosing casts
 * removed;</li>
 * <li>{@code type} is the type referenced by the initial type of
 * {@code ptrPtr} when that type is a pointer type, otherwise the initial
 * type of {@code ptrPtr} itself;</li>
 * <li>{@code scope} is the node created by {@code newHereNode};</li>
 * <li>{@code cudaSuccess} is emitted as an enumeration constant.</li>
 * </ul>
 * The argument expressions are copied, so the original call is left
 * intact.
 *
 * @param cudaMallocCall
 *            the {@code cudaMalloc} call to translate
 * @return the comma expression described above
 */
protected ExpressionNode cudaMallocTransform(FunctionCallNode cudaMallocCall) {
    Source source = cudaMallocCall.getSource();

    // Peel off any casts to reach the pointer expression itself.
    ExpressionNode pointerArg = cudaMallocCall.getArgument(0);
    while (pointerArg instanceof CastNode) {
        pointerArg = ((CastNode) pointerArg).getArgument();
    }

    ExpressionNode sizeArg = cudaMallocCall.getArgument(1);

    // Left-hand side: *ptrPtr
    ExpressionNode dereference = nodeFactory.newOperatorNode(source,
            Operator.DEREFERENCE,
            Arrays.<ExpressionNode> asList(pointerArg.copy()));

    // Type the allocation result is cast to.
    Type pointerArgType = pointerArg.getInitialType();
    Type targetType;
    if (pointerArgType.kind() != TypeKind.POINTER) {
        targetType = pointerArgType;
    } else {
        targetType = ((PointerType) pointerArgType).referencedType();
    }

    // Right-hand side: (type) $malloc(scope, size)
    FunctionCallNode allocation = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$malloc"),
            Arrays.<ExpressionNode> asList(nodeFactory.newHereNode(source),
                    sizeArg.copy()), null);
    CastNode typedAllocation = nodeFactory.newCastNode(source,
            this.typeNode(source, targetType), allocation);

    OperatorNode store = nodeFactory.newOperatorNode(source,
            Operator.ASSIGN,
            Arrays.<ExpressionNode> asList(dereference, typedAllocation));
    ExpressionNode success = nodeFactory
            .newEnumerationConstantNode(nodeFactory.newIdentifierNode(
                    source, "cudaSuccess"));

    // The comma expression yields cudaSuccess as the value of the call.
    return nodeFactory.newOperatorNode(source, Operator.COMMA,
            Arrays.<ExpressionNode> asList(store, success));
}
/**
 * Wraps the body of the given function definition with CUDA runtime
 * set-up and tear-down: a call {@code $cuda_init()} is inserted at index 0
 * of the body and a call {@code $cuda_finalize()} is inserted at the end
 * (at the index equal to the number of children after the first
 * insertion). The resulting compound statement is installed as the new
 * body.
 * <p>
 * NOTE(review): the finalize call is placed after the last existing item,
 * so it would follow any trailing {@code return} statement - confirm this
 * is handled elsewhere.
 *
 * @param mainFunction
 *            the definition whose body is extended
 */
private void transformMainFunctionDefinition(
        FunctionDefinitionNode mainFunction) {
    Source source = mainFunction.getSource();

    FunctionCallNode initCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_init"),
            Collections.<ExpressionNode> emptyList(), null);
    FunctionCallNode finalizeCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_finalize"),
            Collections.<ExpressionNode> emptyList(), null);

    CompoundStatementNode withInit = this.insertToCompoundStatement(
            mainFunction.getBody(),
            nodeFactory.newExpressionStatementNode(initCall), 0);
    int endPosition = withInit.numChildren();
    CompoundStatementNode withFinalize = this.insertToCompoundStatement(
            withInit, nodeFactory.newExpressionStatementNode(finalizeCall),
            endPosition);

    mainFunction.setBody(withFinalize);
}
/**
 * Maps the name of a kernel to the name of the function generated for it
 * by prepending {@code "_cuda_"}.
 *
 * @param name
 *            the original kernel name
 * @return {@code "_cuda_"} followed by {@code name}
 */
private String transformKernelName(String name) {
    StringBuilder generated = new StringBuilder("_cuda_");

    generated.append(name);
    return generated.toString();
}
/**
 * Builds the function definition that replaces a kernel definition.
 * <p>
 * The new function is named by {@link #transformKernelName(String)}, keeps
 * a copy of the original return type, and takes four extra leading
 * parameters - {@code dim3 gridDim}, {@code dim3 blockDim},
 * {@code size_t _cuda_mem_size} and {@code cudaStream_t _cuda_stream} -
 * followed by copies of the original parameters. Its body consists of the
 * nested function built by
 * {@link #buildInnerKernelDefinition(CompoundStatementNode)} from the
 * original body, followed by the call
 * {@code $cuda_enqueue_kernel(_cuda_stream, _cuda_kernel)}.
 *
 * @param oldDefinition
 *            the kernel definition to translate
 * @return the newly built function definition
 */
protected FunctionDefinitionNode kernelDefinitionTransform(
        FunctionDefinitionNode oldDefinition) {
    Source source = oldDefinition.getSource();
    FunctionTypeNode oldType = oldDefinition.getTypeNode();

    // Body: nested kernel function, then the call that enqueues it.
    FunctionDefinitionNode innerKernel = this
            .buildInnerKernelDefinition(oldDefinition.getBody());
    String generatedName = this.transformKernelName(oldDefinition
            .getIdentifier().name());
    FunctionCallNode enqueueCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_enqueue_kernel"),
            Arrays.<ExpressionNode> asList(
                    this.identifierExpression(source, "_cuda_stream"),
                    this.identifierExpression(source, "_cuda_kernel")),
            null);
    List<BlockItemNode> bodyItems = new ArrayList<>();
    bodyItems.add(innerKernel);
    bodyItems.add(nodeFactory.newExpressionStatementNode(enqueueCall));
    CompoundStatementNode generatedBody = nodeFactory
            .newCompoundStatementNode(source, bodyItems);

    // Formals: { parameter name, typedef name } of the launch configuration,
    // in declaration order, followed by copies of the original formals.
    String[][] configurationFormals = { { "gridDim", "dim3" },
            { "blockDim", "dim3" }, { "_cuda_mem_size", "size_t" },
            { "_cuda_stream", "cudaStream_t" } };
    List<VariableDeclarationNode> formalList = new ArrayList<>();
    for (String[] formal : configurationFormals) {
        formalList.add(nodeFactory.newVariableDeclarationNode(source,
                this.identifier(formal[0]),
                nodeFactory.newTypedefNameNode(this.identifier(formal[1]),
                        null)));
    }
    for (VariableDeclarationNode originalFormal : oldType.getParameters()) {
        formalList.add(originalFormal.copy());
    }
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "kernel formals", formalList);

    FunctionTypeNode generatedType = nodeFactory.newFunctionTypeNode(
            source, oldType.getReturnType().copy(), formals, true);

    return nodeFactory.newFunctionDefinitionNode(source,
            nodeFactory.newIdentifierNode(source, generatedName),
            generatedType, null, generatedBody);
}
/**
 * Moves the shared variable declarations out of a compound statement.
 * <p>
 * Every direct item of {@code statements} that is a variable declaration
 * with shared storage is removed from {@code statements}, has its shared
 * storage flag cleared, and is appended to the returned list. Declarations
 * nested deeper inside {@code statements} are not visited.
 * <p>
 * NOTE(review): items are removed while {@code statements} is being
 * iterated; this relies on {@code removeChild} not invalidating the
 * iteration - confirm against the AST node implementation.
 *
 * @param statements
 *            the compound statement to strip of shared declarations
 * @return the removed declarations, in the order they were encountered
 */
protected List<VariableDeclarationNode> extractSharedVariableDeclarations(
        CompoundStatementNode statements) {
    List<VariableDeclarationNode> shared = new ArrayList<>();

    for (BlockItemNode item : statements) {
        if (!(item instanceof VariableDeclarationNode)) {
            continue;
        }

        VariableDeclarationNode declaration = (VariableDeclarationNode) item;

        if (!declaration.hasSharedStorage()) {
            continue;
        }
        statements.removeChild(item.childIndex());
        declaration.setSharedStorage(false);
        shared.add(declaration);
    }
    return shared;
}
/**
 * Builds the definition of the nested function {@code _cuda_kernel}
 * generated for a kernel body.
 * <p>
 * The function returns {@code void} and has the formals
 * {@code $cuda_kernel_instance_t *_cuda_this} and
 * {@code cudaEvent_t _cuda_event}. Its body is, in order:
 * <ol>
 * <li>the {@code _cuda_block} definition built by
 * {@link #buildBlockDefinition(CompoundStatementNode)};</li>
 * <li>{@code $cuda_wait_in_queue(_cuda_this, _cuda_event);}</li>
 * <li>{@code $cuda_run_procs(gridDim, _cuda_block);}</li>
 * <li>{@code $cuda_kernel_finish(_cuda_this);}</li>
 * </ol>
 *
 * @param body
 *            the body of the original kernel
 * @return the definition of {@code _cuda_kernel}
 */
protected FunctionDefinitionNode buildInnerKernelDefinition(
        CompoundStatementNode body) {
    Source source = body.getSource();

    // Formals: ($cuda_kernel_instance_t *_cuda_this, cudaEvent_t _cuda_event)
    VariableDeclarationNode instanceFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_this"),
                    nodeFactory.newPointerTypeNode(source, nodeFactory
                            .newTypedefNameNode(nodeFactory
                                    .newIdentifierNode(source,
                                            "$cuda_kernel_instance_t"),
                                    null)));
    VariableDeclarationNode eventFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_event"),
                    nodeFactory.newTypedefNameNode(nodeFactory
                            .newIdentifierNode(source, "cudaEvent_t"), null));
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "innerKernelFormals",
                    Arrays.asList(instanceFormal, eventFormal));

    FunctionDefinitionNode blockFunction = buildBlockDefinition(body);

    FunctionCallNode waitCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_wait_in_queue"),
            Arrays.<ExpressionNode> asList(
                    this.identifierExpression(source, "_cuda_this"),
                    this.identifierExpression(source, "_cuda_event")),
            null);
    FunctionCallNode runBlocksCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_run_procs"),
            Arrays.<ExpressionNode> asList(
                    this.identifierExpression(source, "gridDim"),
                    this.identifierExpression(source, "_cuda_block")),
            null);
    FunctionCallNode finishCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_kernel_finish"),
            Arrays.<ExpressionNode> asList(this.identifierExpression(source,
                    "_cuda_this")), null);

    List<BlockItemNode> bodyItems = new ArrayList<>();
    bodyItems.add(blockFunction);
    bodyItems.add(nodeFactory.newExpressionStatementNode(waitCall));
    bodyItems.add(nodeFactory.newExpressionStatementNode(runBlocksCall));
    bodyItems.add(nodeFactory.newExpressionStatementNode(finishCall));
    CompoundStatementNode kernelBody = nodeFactory
            .newCompoundStatementNode(source, bodyItems);

    FunctionTypeNode kernelType = nodeFactory.newFunctionTypeNode(source,
            nodeFactory.newVoidTypeNode(source), formals, false);

    return nodeFactory.newFunctionDefinitionNode(source,
            nodeFactory.newIdentifierNode(source, "_cuda_kernel"),
            kernelType, null, kernelBody);
}
/**
 * Builds the definition of the nested function {@code _cuda_block}
 * generated for a kernel body.
 * <p>
 * The function returns {@code void} and takes a single formal
 * {@code uint3 blockIdx}. Working on a copy of {@code body}, the shared
 * variable declarations are moved out of the thread code (see
 * {@link #extractSharedVariableDeclarations(CompoundStatementNode)} and
 * {@link #completeSharedExternArrays(List)}). The generated body is, in
 * order:
 * <ol>
 * <li>{@code $gbarrier _cuda_block_barrier = $gbarrier_create(scope,
 * blockDim.x * blockDim.y * blockDim.z);} where {@code scope} is the node
 * created by {@code newHereNode};</li>
 * <li>the extracted shared variable declarations;</li>
 * <li>the {@code _cuda_thread} definition built by
 * {@link #buildThreadDefinition(CompoundStatementNode)};</li>
 * <li>{@code $cuda_run_procs(blockDim, _cuda_thread);}</li>
 * <li>{@code $gbarrier_destroy(_cuda_block_barrier);}</li>
 * </ol>
 *
 * @param body
 *            the body of the original kernel; it is copied, not modified
 * @return the definition of {@code _cuda_block}
 */
protected FunctionDefinitionNode buildBlockDefinition(
        CompoundStatementNode body) {
    Source source = body.getSource();
    CompoundStatementNode threadCode = body.copy();

    // Number of threads per block: (blockDim.x * blockDim.y) * blockDim.z
    String[] axes = { "x", "y", "z" };
    ExpressionNode[] extents = new ExpressionNode[axes.length];
    for (int i = 0; i < axes.length; i++) {
        extents[i] = nodeFactory.newDotNode(source,
                this.identifierExpression(source, "blockDim"),
                nodeFactory.newIdentifierNode(source, axes[i]));
    }
    ExpressionNode threadCount = extents[0];
    for (int i = 1; i < extents.length; i++) {
        threadCount = nodeFactory.newOperatorNode(source, Operator.TIMES,
                Arrays.<ExpressionNode> asList(threadCount, extents[i]));
    }

    FunctionCallNode createGbarrier = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$gbarrier_create"),
            Arrays.<ExpressionNode> asList(nodeFactory.newHereNode(source),
                    threadCount), null);
    VariableDeclarationNode gbarrierDeclaration = nodeFactory
            .newVariableDeclarationNode(source, nodeFactory
                    .newIdentifierNode(source, "_cuda_block_barrier"),
                    nodeFactory.newTypedefNameNode(nodeFactory
                            .newIdentifierNode(source, "$gbarrier"), null),
                    createGbarrier);

    // Shared variables must leave the thread code before the thread
    // function is built from it.
    List<VariableDeclarationNode> sharedDeclarations = this
            .extractSharedVariableDeclarations(threadCode);
    completeSharedExternArrays(sharedDeclarations);

    VariableDeclarationNode blockIdxFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "blockIdx"),
                    nodeFactory.newTypedefNameNode(
                            nodeFactory.newIdentifierNode(source, "uint3"),
                            null));
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "blockFormals",
                    Arrays.asList(blockIdxFormal));

    FunctionDefinitionNode threadFunction = this
            .buildThreadDefinition(threadCode);

    FunctionCallNode runThreadsCall = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$cuda_run_procs"),
            Arrays.<ExpressionNode> asList(
                    this.identifierExpression(source, "blockDim"),
                    this.identifierExpression(source, "_cuda_thread")),
            null);
    FunctionCallNode destroyGbarrier = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$gbarrier_destroy"),
            Arrays.<ExpressionNode> asList(this.identifierExpression(source,
                    "_cuda_block_barrier")), null);

    List<BlockItemNode> bodyItems = new ArrayList<BlockItemNode>();
    bodyItems.add(gbarrierDeclaration);
    for (VariableDeclarationNode sharedDeclaration : sharedDeclarations) {
        bodyItems.add(sharedDeclaration);
    }
    bodyItems.add(threadFunction);
    bodyItems.add(nodeFactory.newExpressionStatementNode(runThreadsCall));
    bodyItems.add(nodeFactory.newExpressionStatementNode(destroyGbarrier));
    CompoundStatementNode blockBody = nodeFactory.newCompoundStatementNode(
            source, bodyItems);

    FunctionTypeNode blockType = nodeFactory.newFunctionTypeNode(source,
            nodeFactory.newVoidTypeNode(source), formals, false);

    return nodeFactory.newFunctionDefinitionNode(source,
            nodeFactory.newIdentifierNode(source, "_cuda_block"), blockType,
            null, blockBody);
}
/**
 * Completes the extern array declarations found among the given shared
 * variables. A declaration that has extern storage, whose type node is an
 * array type, and whose array type has no extent gets the identifier
 * expression {@code _cuda_mem_size} as its extent and has its extern
 * storage flag cleared. All other declarations are left untouched.
 *
 * @param sharedVars
 *            the declarations to inspect; matching ones are modified in
 *            place
 */
protected void completeSharedExternArrays(
        List<VariableDeclarationNode> sharedVars) {
    for (VariableDeclarationNode declaration : sharedVars) {
        if (!declaration.hasExternStorage()) {
            continue;
        }
        if (declaration.getTypeNode().kind() != TypeNodeKind.ARRAY) {
            continue;
        }

        ArrayTypeNode arrayType = (ArrayTypeNode) declaration.getTypeNode();

        if (arrayType.getExtent() != null) {
            continue;
        }
        arrayType.setExtent(this.identifierExpression("_cuda_mem_size"));
        declaration.setExternStorage(false);
    }
}
/**
 * Builds the definition of the nested function {@code _cuda_thread}
 * generated for a kernel body.
 * <p>
 * The function returns {@code void} and takes a single formal
 * {@code uint3 threadIdx}. The generated body is, in order:
 * <ol>
 * <li>{@code int _cuda_tid = $cuda_index(blockDim, threadIdx);}</li>
 * <li>{@code $barrier _cuda_thread_barrier = $barrier_create(scope,
 * _cuda_block_barrier, _cuda_tid);} where {@code scope} is the node
 * created by {@code newHereNode};</li>
 * <li>copies of the non-null items of {@code body};</li>
 * <li>{@code $barrier_destroy(_cuda_thread_barrier);}</li>
 * </ol>
 * Finally every {@code __syncthreads} call inside the new definition is
 * replaced by {@code $barrier_call(_cuda_thread_barrier)} (see
 * {@link #replaceSyncThreadsCalls(ASTNode, ExpressionNode)}).
 *
 * @param body
 *            the thread code; its items are copied, not moved
 * @return the definition of {@code _cuda_thread}
 */
protected FunctionDefinitionNode buildThreadDefinition(
        CompoundStatementNode body) {
    Source source = body.getSource();

    VariableDeclarationNode threadIdxFormal = nodeFactory
            .newVariableDeclarationNode(source,
                    nodeFactory.newIdentifierNode(source, "threadIdx"),
                    nodeFactory.newTypedefNameNode(
                            nodeFactory.newIdentifierNode(source, "uint3"),
                            null));
    SequenceNode<VariableDeclarationNode> formals = nodeFactory
            .newSequenceNode(source, "threadFormals",
                    Arrays.asList(threadIdxFormal));

    // int _cuda_tid = $cuda_index(blockDim, threadIdx);
    FunctionCallNode indexCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$cuda_index"),
            Arrays.<ExpressionNode> asList(
                    this.identifierExpression(source, "blockDim"),
                    this.identifierExpression(source, "threadIdx")), null);
    VariableDeclarationNode tidDeclaration = nodeFactory
            .newVariableDeclarationNode(source,
                    this.identifier("_cuda_tid"),
                    nodeFactory.newBasicTypeNode(source, BasicTypeKind.INT),
                    indexCall);

    // $barrier _cuda_thread_barrier = $barrier_create(...);
    FunctionCallNode createBarrier = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$barrier_create"),
            Arrays.<ExpressionNode> asList(nodeFactory.newHereNode(source),
                    this.identifierExpression(source, "_cuda_block_barrier"),
                    this.identifierExpression("_cuda_tid")), null);
    VariableDeclarationNode barrierDeclaration = nodeFactory
            .newVariableDeclarationNode(source, nodeFactory
                    .newIdentifierNode(source, "_cuda_thread_barrier"),
                    nodeFactory.newTypedefNameNode(nodeFactory
                            .newIdentifierNode(source, "$barrier"), null),
                    createBarrier);

    FunctionCallNode destroyBarrier = nodeFactory.newFunctionCallNode(
            source, this.identifierExpression(source, "$barrier_destroy"),
            Arrays.<ExpressionNode> asList(this.identifierExpression(source,
                    "_cuda_thread_barrier")), null);

    List<BlockItemNode> bodyItems = new ArrayList<BlockItemNode>();
    bodyItems.add(tidDeclaration);
    bodyItems.add(barrierDeclaration);
    for (BlockItemNode item : body) {
        // Slots emptied by earlier passes show up as null items.
        if (item == null) {
            continue;
        }
        bodyItems.add(item.copy());
    }
    bodyItems.add(nodeFactory.newExpressionStatementNode(destroyBarrier));
    CompoundStatementNode generatedBody = nodeFactory
            .newCompoundStatementNode(source, bodyItems);

    FunctionTypeNode threadType = nodeFactory.newFunctionTypeNode(source,
            nodeFactory.newVoidTypeNode(source), formals, false);
    FunctionDefinitionNode threadFunction = nodeFactory
            .newFunctionDefinitionNode(source,
                    nodeFactory.newIdentifierNode(source, "_cuda_thread"),
                    threadType, null, generatedBody);

    FunctionCallNode barrierCall = nodeFactory.newFunctionCallNode(source,
            this.identifierExpression(source, "$barrier_call"),
            Arrays.<ExpressionNode> asList(this.identifierExpression(source,
                    "_cuda_thread_barrier")), null);
    replaceSyncThreadsCalls(threadFunction, barrierCall);

    return threadFunction;
}
/**
 * Walks the subtree below {@code root} and replaces every function call
 * whose function is the identifier {@code __syncthreads} by a fresh copy
 * of {@code replacement}. A replaced call is not descended into; every
 * other non-null child is processed recursively.
 *
 * @param root
 *            the node whose descendants are searched
 * @param replacement
 *            the expression to substitute; it is copied for each
 *            occurrence
 */
protected void replaceSyncThreadsCalls(ASTNode root,
        ExpressionNode replacement) {
    for (ASTNode node : root.children()) {
        if (node == null) {
            continue;
        }
        if (isSyncThreadsCall(node)) {
            root.setChild(node.childIndex(), replacement.copy());
        } else {
            replaceSyncThreadsCalls(node, replacement);
        }
    }
}

/**
 * Tells whether {@code node} is a function call expression whose function
 * is an identifier expression named {@code __syncthreads}.
 */
private boolean isSyncThreadsCall(ASTNode node) {
    if (!(node instanceof ExpressionNode)
            || !(node instanceof FunctionCallNode)) {
        return false;
    }

    ExpressionNode callee = ((FunctionCallNode) node).getFunction();

    if (!(callee instanceof IdentifierExpressionNode)) {
        return false;
    }

    String calleeName = ((IdentifierExpressionNode) callee).getIdentifier()
            .name();

    return calleeName.equals("__syncthreads");
}
/**
 * Builds the statement that replaces a kernel launch.
 * <p>
 * The result is a compound statement containing a call to the generated
 * kernel function (see {@link #transformKernelName(String)}; a function
 * expression that is not a plain identifier is copied unchanged). The
 * argument list of that call is:
 * <ol>
 * <li>the first two context arguments of the launch. A context argument
 * whose converted type (ignoring qualifiers) is the basic type {@code int}
 * is first stored in a fresh {@code dim3} temporary initialized with
 * {@code $cuda_to_dim3(arg)}, and the temporary is passed instead;</li>
 * <li>the third and fourth context arguments, each defaulting to the
 * integer constant {@code 0} when the launch does not supply it;</li>
 * <li>copies of the ordinary arguments of the launch.</li>
 * </ol>
 * The declarations of the temporaries precede the call inside the
 * compound statement.
 *
 * @param kernelCall
 *            the kernel launch; must have at least two context arguments
 * @return the compound statement described above
 * @throws IllegalStateException
 *             if the default constant {@code 0} cannot be created
 */
protected StatementNode kernelCallTransform(FunctionCallNode kernelCall) {
    Source source = kernelCall.getSource();
    int numContextArguments = kernelCall.getNumberOfContextArguments();
    List<VariableDeclarationNode> tempVarDecls = new ArrayList<>();
    List<ExpressionNode> newArgumentList = new ArrayList<>();

    // Grid and block dimensions: plain ints are widened to dim3 through a
    // temporary variable.
    for (int i = 0; i < 2; i++) {
        ExpressionNode arg = kernelCall.getContextArgument(i);
        Type argType = arg.getConvertedType();

        if (argType.kind() == TypeKind.QUALIFIED) {
            argType = ((QualifiedObjectType) argType).getBaseType();
        }
        if (argType.kind() == TypeKind.BASIC
                && ((StandardBasicType) argType).getBasicTypeKind() == BasicTypeKind.INT) {
            String tmpVar = newTemporaryVariableName();
            ExpressionNode intConvertedToDim3 = nodeFactory
                    .newFunctionCallNode(source,
                            this.identifierExpression("$cuda_to_dim3"),
                            Arrays.asList(arg.copy()), null);

            tempVarDecls.add(nodeFactory.newVariableDeclarationNode(source,
                    this.identifier(tmpVar), nodeFactory
                            .newTypedefNameNode(this.identifier("dim3"),
                                    null), intConvertedToDim3));
            newArgumentList.add(this.identifierExpression(tmpVar));
        } else {
            newArgumentList.add(arg.copy());
        }
    }

    // Optional third and fourth context arguments default to 0.
    for (int i = 2; i < 4; i++) {
        if (i < numContextArguments) {
            newArgumentList.add(kernelCall.getContextArgument(i).copy());
        } else {
            newArgumentList.add(defaultContextArgument(source));
        }
    }

    for (int i = 0; i < kernelCall.getNumberOfArguments(); i++) {
        newArgumentList.add(kernelCall.getArgument(i).copy());
    }

    ExpressionNode newFunction;
    if (kernelCall.getFunction() instanceof IdentifierExpressionNode) {
        IdentifierExpressionNode identifierExpression = (IdentifierExpressionNode) kernelCall
                .getFunction();

        newFunction = this
                .identifierExpression(transformKernelName(identifierExpression
                        .getIdentifier().name()));
    } else {
        newFunction = kernelCall.getFunction().copy();
    }

    FunctionCallNode newFunctionCall = nodeFactory.newFunctionCallNode(
            source, newFunction, newArgumentList, null);
    List<BlockItemNode> blockItems = new ArrayList<>();

    blockItems.addAll(tempVarDecls);
    blockItems.add(nodeFactory.newExpressionStatementNode(newFunctionCall));
    return nodeFactory.newCompoundStatementNode(source, blockItems);
}

/**
 * Creates the integer constant {@code 0} used for a context argument that
 * the kernel launch does not supply.
 * <p>
 * A failure to create the constant is propagated instead of being logged
 * and ignored: silently skipping the argument would shift every following
 * argument of the generated call by one position.
 */
private ExpressionNode defaultContextArgument(Source source) {
    try {
        return nodeFactory.newIntegerConstantNode(source, "0");
    } catch (SyntaxException e) {
        throw new IllegalStateException(
                "Unable to create the default constant 0 for an omitted"
                        + " kernel launch context argument", e);
    }
}
/**
 * Removes from the subtree below {@code root} every variable declaration
 * that declares one of the CUDA builtin variables {@code threadIdx},
 * {@code blockIdx}, {@code gridDim} or {@code blockDim} and whose
 * identifier originates from a source file named {@code cuda.h}.
 * <p>
 * The set of builtin names is built once here and shared by the whole
 * traversal rather than being rebuilt for every visited node.
 *
 * @param root
 *            the node whose descendants are searched
 */
protected void removeBuiltinDefinitions(ASTNode root) {
    Set<String> builtinVariables = new HashSet<>(Arrays.asList("threadIdx",
            "blockIdx", "gridDim", "blockDim"));

    removeBuiltinDefinitionsIn(root, builtinVariables);
}

/**
 * Recursive step of {@link #removeBuiltinDefinitions(ASTNode)}: removes the
 * matching declarations among the children of {@code root} and descends
 * into every other non-null child.
 */
private void removeBuiltinDefinitionsIn(ASTNode root,
        Set<String> builtinVariables) {
    for (ASTNode child : root.children()) {
        if (child == null) {
            continue;
        }
        if (child.nodeKind() == NodeKind.VARIABLE_DECLARATION) {
            VariableDeclarationNode variableDeclaration = (VariableDeclarationNode) child;

            // Only declarations coming from the CUDA header are dropped; a
            // user variable that happens to share a builtin name is kept.
            if (variableDeclaration.getIdentifier() != null
                    && variableDeclaration.getIdentifier().getSource()
                            .getFirstToken().getSourceFile().getName()
                            .equals("cuda.h")
                    && builtinVariables.contains(variableDeclaration
                            .getIdentifier().name())) {
                root.removeChild(child.childIndex());
                continue;
            }
        }
        removeBuiltinDefinitionsIn(child, builtinVariables);
    }
}
}