Cuda2CIVLWorker.java
package dev.civl.mc.transform.common;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import dev.civl.abc.ast.IF.AST;
import dev.civl.abc.ast.IF.ASTFactory;
import dev.civl.abc.ast.entity.IF.Function;
import dev.civl.abc.ast.node.IF.ASTNode;
import dev.civl.abc.ast.node.IF.ASTNode.NodeKind;
import dev.civl.abc.ast.node.IF.SequenceNode;
import dev.civl.abc.ast.node.IF.declaration.DeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.EnumeratorDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.FieldDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.FunctionDefinitionNode;
import dev.civl.abc.ast.node.IF.declaration.TypedefDeclarationNode;
import dev.civl.abc.ast.node.IF.declaration.VariableDeclarationNode;
import dev.civl.abc.ast.node.IF.expression.CastNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode;
import dev.civl.abc.ast.node.IF.expression.ExpressionNode.ExpressionKind;
import dev.civl.abc.ast.node.IF.expression.FunctionCallNode;
import dev.civl.abc.ast.node.IF.expression.IdentifierExpressionNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode;
import dev.civl.abc.ast.node.IF.expression.OperatorNode.Operator;
import dev.civl.abc.ast.node.IF.label.SwitchLabelNode;
import dev.civl.abc.ast.node.IF.statement.BlockItemNode;
import dev.civl.abc.ast.node.IF.statement.CompoundStatementNode;
import dev.civl.abc.ast.node.IF.statement.ExpressionStatementNode;
import dev.civl.abc.ast.node.IF.statement.StatementNode;
import dev.civl.abc.ast.node.IF.statement.StatementNode.StatementKind;
import dev.civl.abc.ast.node.IF.type.ArrayTypeNode;
import dev.civl.abc.ast.node.IF.type.EnumerationTypeNode;
import dev.civl.abc.ast.node.IF.type.FunctionTypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode;
import dev.civl.abc.ast.node.IF.type.TypeNode.TypeNodeKind;
import dev.civl.abc.ast.type.IF.EnumerationType;
import dev.civl.abc.ast.type.IF.PointerType;
import dev.civl.abc.ast.type.IF.QualifiedObjectType;
import dev.civl.abc.ast.type.IF.StandardBasicType;
import dev.civl.abc.ast.type.IF.StandardBasicType.BasicTypeKind;
import dev.civl.abc.ast.type.IF.Type;
import dev.civl.abc.ast.type.IF.Type.TypeKind;
import dev.civl.abc.front.IF.CivlcTokenConstant;
import dev.civl.abc.token.IF.CivlcToken;
import dev.civl.abc.token.IF.Formation;
import dev.civl.abc.token.IF.Source;
import dev.civl.abc.token.IF.StringLiteral;
import dev.civl.abc.token.IF.StringToken;
import dev.civl.abc.token.IF.SyntaxException;
import dev.civl.abc.token.IF.TokenFactory;
import dev.civl.abc.token.IF.CivlcToken.TokenVocabulary;
import dev.civl.mc.model.IF.CIVLSyntaxException;
import dev.civl.mc.transform.IF.Cuda2CIVLTransformer;
import dev.civl.mc.util.IF.Pair;
public class Cuda2CIVLWorker extends BaseWorker {
private static final String CUDA_HEADER = "cuda.h";
private int tempVarNum;
private static final Set<String> builtinCUDAVariables = Set.of("threadIdx",
"blockIdx", "gridDim", "blockDim");
private static final String CUDA_TAG_ENUM_NAME = "$cuda_tag";
private static final String HOST_COMM_NAME = "$cuda_host_comm";
private static final String HOST_PLACE_NAME = "$CUDA_PLACE_HOST";
private static final String DEVICE_PLACE_NAME = "$CUDA_PLACE_DEVICE";
private static final String DEVICE_GLOB_CONTEXT_NAME = "$cuda_global_context";
private static final String HOST_MAIN = "$host_main";
private static final String CUDA_MAIN = "$cuda_main";
private EnumerationType cudaTagEnumType = null;
private Map<String, KernelInfo> kernelMap = new HashMap<String, KernelInfo>();
private Set<ExpressionStatementNode> kernelCalls = new HashSet<ExpressionStatementNode>();
private Map<String, Function> deviceFunctionMap = new HashMap<String, Function>();
private int numSyncthreads = 0;
protected class KernelInfo {
public Function entity;
public KernelInfo(Function entity) {
this.entity = entity;
}
public CompoundStatementNode kernelBody = null;
/**
* @return the name of the enum value that is used to signal to the
* device to launch this kernel.
*/
public String getTagName() {
return "$CUDA_TAG_LAUNCH_" + entity.getName();
}
public String getParamStructName() {
return "$cuda_" + entity.getName() + "_data";
}
public String getArgRevealFunctionName() {
return "$cuda_reveal_" + entity.getName() + "_args";
}
public String getKernelProcName() {
return "$cuda_" + entity.getName() + "_proc";
}
public String getLaunchFunctionName() {
return "$cuda_host_launch_" + entity.getName();
}
public FunctionDefinitionNode getDefinition() {
return entity.getDefinition();
}
/**
* Generates a new struct (wrapped inside of a typedef to avoid needing
* to prepend the "struct" keyword before uses of this type) that has a
* field for each formal parameter that this kernel takes.
*
* This struct is used for passing passing kernel parameters from the
* host to the device via communicators.
*/
public TypedefDeclarationNode generateParameterStruct() {
String srcMethod = entity.getName() + ".generateParameterStruct";
List<Pair<String, String>> contextParams = contextParams(false);
List<FieldDeclarationNode> fieldList = new ArrayList<FieldDeclarationNode>(
entity.getType().getNumParameters() + contextParams.size());
for (Pair<String, String> contextParam : contextParams) {
fieldList.add(nodeDeclField(srcMethod, contextParam.left,
nodeTypeNamed(srcMethod, contextParam.right)));
}
for (VariableDeclarationNode param : getDefinition().getTypeNode()
.getParameters()) {
if (param.getTypeNode().kind() == TypeNodeKind.VOID)
continue;
fieldList.add(nodeDeclField(srcMethod, param.getName(),
param.getTypeNode().copy()));
}
for (FieldDeclarationNode field : fieldList) {
field.getTypeNode().setConstQualified(false);
}
return nodeTypeDefStruct(srcMethod, getParamStructName(),
fieldList);
}
/**
* Generates a function which takes a pointer to the parameter struct
* associated to this kernel and $reveal's any pointer parameters.
*/
public FunctionDefinitionNode generateArgRevealFunction() {
String srcMethod = entity.getName() + ".generateArgRevealFunction";
List<BlockItemNode> bodyList = new LinkedList<>();
for (VariableDeclarationNode formalDecl : generateFormalParameters(
entity.getName(), getDefinition().getTypeNode(), false)) {
ExpressionNode argNode = nodeExprOp(
srcMethod, OperatorNode.Operator.ADDRESSOF, nodeExprArrow(srcMethod,
nodeExprId(srcMethod, "args"),
formalDecl.getName()));
bodyList.add(nodeStmtCall(srcMethod, "$reveal",
nodeExprCast(srcMethod,
nodeTypePointer(srcMethod, voidType()),
argNode)));
}
return nodeDefnFunction(srcMethod, getArgRevealFunctionName(),
voidType(),
Arrays.asList(
nodeDeclVar(srcMethod, "args",
nodeTypePointer(srcMethod,
nodeTypeNamed(srcMethod,
getParamStructName())))),
bodyList);
}
private FunctionDefinitionNode generateKernelThreadDefinition() {
String srcMethod = entity.getName()
+ ".generateKernelThreadDefinition";
List<BlockItemNode> bodyList = new LinkedList<>();
bodyList.add(nodeStmtCall(srcMethod, "$local_start"));
bodyList.add(nodeDeclVarInit(srcMethod, "$cuda_tid",
nodeTypeInt(srcMethod),
nodeExprCall(srcMethod, "$cuda_dim3_index",
nodeExprId(srcMethod, "blockDim"),
nodeExprId(srcMethod, "threadIdx"))));
bodyList.add(nodeDeclVarInit(srcMethod, "$cuda_kid",
nodeTypeInt(srcMethod),
nodeExprCall(srcMethod, "$cuda_kernel_index",
nodeExprId(srcMethod, "gridDim"),
nodeExprId(srcMethod, "blockDim"),
nodeExprId(srcMethod, "blockIdx"),
nodeExprId(srcMethod, "threadIdx"))));
bodyList.add(nodeDeclVarInit(srcMethod, "$thread",
nodeTypeNamed(srcMethod, "$cuda_thread_data_t"),
nodeExprCall(srcMethod, "$create_cuda_thread_data",
nodeExprHere(srcMethod),
nodeExprId(srcMethod, "$kernel"),
nodeExprOp(srcMethod, Operator.DIV,
nodeExprId(srcMethod, "$cuda_kid"),
nodeExprOp(srcMethod, Operator.TIMES,
nodeExprDot(srcMethod,
nodeExprId(srcMethod,
"blockDim"),
"x"),
nodeExprOp(srcMethod,
Operator.TIMES,
nodeExprDot(srcMethod,
nodeExprId(
srcMethod,
"blockDim"),
"y"),
nodeExprDot(srcMethod,
nodeExprId(
srcMethod,
"blockDim"),
"z")))),
nodeExprOp(srcMethod, Operator.DIV,
nodeExprId(srcMethod, "$cuda_tid"),
nodeExprId(srcMethod, "warpSize")),
nodeExprOp(srcMethod, Operator.MOD,
nodeExprId(srcMethod, "$cuda_tid"),
nodeExprId(srcMethod, "warpSize")))));
for (BlockItemNode stmt : kernelBody) {
if (stmt != null) {
bodyList.add(stmt.copy());
}
}
bodyList.add(nodeStmtCall(srcMethod, "$destroy_cuda_thread_data",
nodeExprId(srcMethod, "$thread")));
bodyList.add(nodeStmtCall(srcMethod, "$local_end"));
return nodeDefnFunction(srcMethod, "$cuda_thread", voidType(),
Arrays.asList(nodeDeclVar(srcMethod, "threadIdx",
nodeTypeNamed(srcMethod, "uint3"))),
bodyList);
}
private FunctionDefinitionNode generateKernelBlockDefinition() {
String srcMethod = entity.getName()
+ ".generateKernelBlockDefinition";
List<BlockItemNode> bodyList = new LinkedList<>();
List<VariableDeclarationNode> sharedVars = extractSharedVariableDeclarations(
entity.getDefinition().getBody());
completeSharedExternArrays(sharedVars);
bodyList.addAll(sharedVars);
bodyList.add(generateKernelThreadDefinition());
bodyList.add(nodeStmtCall(srcMethod, "$cuda_run_and_wait_on_procs",
nodeExprId(srcMethod, "blockDim"),
nodeExprId(srcMethod, "$cuda_thread")));
return nodeDefnFunction(srcMethod, "$cuda_block", voidType(),
Arrays.asList(nodeDeclVar(srcMethod, "blockIdx",
nodeTypeNamed(srcMethod, "uint3"))),
bodyList);
}
/**
* Generates a transformed definition of this kernel to emulate the
* thread hierarchy of CUDA kernels as well as to inject data race
* checks.
*/
public FunctionDefinitionNode generateTransformedKernelDefinition() {
String srcMethod = entity.getName()
+ ".generateTransformedKernelDefinition";
kernelBody = entity.getDefinition().getBody();
// We transform the kernel body first because it also scans the body
// for information we need ahead of time like the number of
// __syncthreads() calls.
transformBodyOfCudaFunction(kernelBody);
List<BlockItemNode> bodyList = new LinkedList<>();
// Need to initialize in order to pass in to $cuda_syncthreads
bodyList.add(nodeDeclVarInit(srcMethod, "$kernel",
nodeTypeNamed(srcMethod, "$cuda_kernel_data_t"),
nodeExprCall(srcMethod, "$create_cuda_kernel_data",
nodeExprHere(srcMethod),
nodeExprId(srcMethod, "gridDim"),
nodeExprId(srcMethod, "blockDim"))));
bodyList.add(generateKernelBlockDefinition());
bodyList.add(nodeStmtCall(srcMethod, "$cuda_run_and_wait_on_procs",
nodeExprId(srcMethod, "gridDim"),
nodeExprId(srcMethod, "$cuda_block")));
bodyList.add(nodeStmtCall(srcMethod, "$destroy_cuda_kernel_data",
nodeExprId(srcMethod, "$kernel")));
return nodeDefnFunction(srcMethod,
transformCudaFunctionName(entity.getName()), voidType(),
generateFormalParameters(entity.getName(),
getDefinition().getTypeNode(), false),
bodyList);
}
/**
* Generates a transformed declaration of this kernel which uses the new
* transformed name and includes the context parameters as regular
* formal parameters.
*/
public FunctionDeclarationNode generateTransformedKernelDeclaration() {
String srcMethod = entity.getName()
+ ".generateTransformedKernelDeclaration";
return nodeDeclFunction(srcMethod,
transformCudaFunctionName(entity.getName()), voidType(),
generateFormalParameters(entity.getName(),
getDefinition().getTypeNode(), false));
}
public FunctionDefinitionNode generateKernelProcDefinition() {
String srcMethod = entity.getName()
+ ".generateKernelProcDefinition";
List<BlockItemNode> bodyList = new LinkedList<>();
bodyList.add(nodeStmtWhen(srcMethod, nodeExprArrow(srcMethod,
nodeExprId(srcMethod, "opState"), "start")));
bodyList.add(nodeDeclVar(srcMethod, "args",
nodeTypeNamed(srcMethod, getParamStructName())));
bodyList.add(nodeStmtCall(srcMethod, "$message_unpack",
nodeExprId(srcMethod, "request"),
nodeExprOp(srcMethod, Operator.ADDRESSOF,
nodeExprId(srcMethod, "args")),
nodeExprSizeof(srcMethod,
nodeTypeNamed(srcMethod, getParamStructName()))));
bodyList.add(nodeStmtCall(srcMethod, getArgRevealFunctionName(),
nodeExprOp(srcMethod, Operator.ADDRESSOF,
nodeExprId(srcMethod, "args"))));
List<VariableDeclarationNode> formals = generateFormalParameters(
entity.getName(), getDefinition().getTypeNode(), false);
ExpressionNode kernelArgs[] = new ExpressionNode[formals.size()];
int i = 0;
for (VariableDeclarationNode paramDeclNode : formals) {
kernelArgs[i] = nodeExprDot(srcMethod,
nodeExprId(srcMethod, "args"), paramDeclNode.getName());
i++;
}
bodyList.add(nodeStmtCall(srcMethod,
transformCudaFunctionName(entity.getName()), kernelArgs));
bodyList.add(nodeStmtCall(srcMethod, "$stream_dequeue",
nodeExprId(srcMethod, "cudaStream")));
return nodeDefnFunction(srcMethod, getKernelProcName(), voidType(),
Arrays.asList(
nodeDeclVar(srcMethod, "request",
nodeTypeNamed(srcMethod, "$message")),
nodeDeclVar(srcMethod, "opState",
nodeTypeNamed(srcMethod,
"$cuda_op_state_t")),
nodeDeclVar(srcMethod, "cudaStream",
nodeTypeNamed(srcMethod, "cudaStream_t"))),
bodyList);
}
/**
* Generates a new launch function definition which the host will call
* in place of a kernel launch.
*
* This function takes the context parameters of the kernel followed by
* the kernel's formal parameters. It then puts all of these parameters
* into an instance of the kernel's parameter structure. Then it packs
* this object into a message and sends it to the device. Then it waits
* for a response from the device before moving on.
*
* @return
*/
public FunctionDefinitionNode generateKernelLaunchFunction() {
String srcMethod = entity.getName()
+ ".generateKernelLaunchFunction";
List<VariableDeclarationNode> formals = generateFormalParameters(
entity.getName(), getDefinition().getTypeNode(), false);
List<BlockItemNode> bodyList = new LinkedList<>();
bodyList.add(nodeDeclVar(srcMethod, "args",
nodeTypeNamed(srcMethod, getParamStructName())));
for (VariableDeclarationNode formal : formals) {
bodyList.add(nodeStmtAssign(srcMethod,
nodeExprDot(srcMethod, nodeExprId(srcMethod, "args"),
formal.getName()),
nodeExprId(srcMethod, formal.getName())));
}
ExpressionNode messagePackExpr = nodeExprCall(srcMethod,
"$message_pack", nodeExprId(srcMethod, HOST_PLACE_NAME),
nodeExprId(srcMethod, DEVICE_PLACE_NAME),
nodeFactory.newEnumerationConstantNode(
nodeIdent(srcMethod, getTagName())),
nodeExprOp(srcMethod, Operator.ADDRESSOF,
nodeExprId(srcMethod, "args")),
nodeExprSizeof(srcMethod,
nodeTypeNamed(srcMethod, getParamStructName())));
bodyList.add(nodeStmtCall(srcMethod, "$comm_enqueue",
nodeExprId(srcMethod, HOST_COMM_NAME), messagePackExpr));
bodyList.add(nodeStmtCall(srcMethod, "$comm_dequeue",
nodeExprId(srcMethod, HOST_COMM_NAME),
nodeExprId(srcMethod, DEVICE_PLACE_NAME),
nodeFactory.newEnumerationConstantNode(
nodeIdent(srcMethod, getTagName()))));
return nodeDefnFunction(srcMethod, getLaunchFunctionName(),
voidType(), formals, bodyList);
}
}
public Cuda2CIVLWorker(ASTFactory astFactory) {
super(Cuda2CIVLTransformer.LONG_NAME, astFactory);
identifierPrefix = "_cuda_";
}
@Override
protected AST transformCore(AST ast) throws SyntaxException {
if (!hasHeader(ast, CUDA_HEADER))
return ast;
SequenceNode<BlockItemNode> root = ast.getRootNode();
ast.release();
scanTree(root);
assert cudaTagEnumType != null;
addEnumTags();
translateDeviceFunctions();
executeKernelTransformations();
executeKernelCallTransformations();
translateCudaMallocCalls(root);
if (!has_gen_mainFunction(root)) {
transformMainFunction(root);
createNewMainFunction(root);
}
translateMainDefinition(root);
// translateKernelDefinitions(root);
// translateKernelDeclarations(root);
AST newAST = astFactory.newAST(root, ast.getSourceFiles(),
ast.isWholeProgram());
//newAST.prettyPrint(System.out, false);
return newAST;
}
/**
* Returns a new temporary variable each time it is called.
*
* @return A generated temporary variable name
*/
protected String newTemporaryVariableName() {
return identifierPrefix + "tmp" + tempVarNum++;
}
/**
* Scans the entire AST recursively, collecting nodes and entities of
* interest. Also performs some light transformations such as removing all
* of the built-in CUDA variables.
*
* @param root
* The root of the AST to be scanned.
*/
protected void scanTree(ASTNode root) {
for (ASTNode child : root.children()) {
if (child == null)
continue;
switch (child.nodeKind()) {
case TYPE :
if (cudaTagEnumType == null && ((TypeNode) child)
.kind() == TypeNodeKind.ENUMERATION) {
EnumerationTypeNode enumTypeNode = (EnumerationTypeNode) child;
if (enumTypeNode.getTag().name()
.equals(CUDA_TAG_ENUM_NAME)) {
cudaTagEnumType = enumTypeNode.getType();
}
}
break;
case FUNCTION_DEFINITION :
FunctionDefinitionNode definition = (FunctionDefinitionNode) child;
if (definition.hasDeviceFunctionSpecifier() && definition
.getTypeNode() instanceof FunctionTypeNode) {
String funcName = definition.getName();
if (funcName == null) {
throw new CIVLSyntaxException(
"__device__ functions cannot be anonymous",
definition.getSource());
}
addDeviceFunction(funcName, definition.getEntity());
}
if (definition.hasGlobalFunctionSpecifier()) {
String kernelName = definition.getName();
if (kernelName == null) {
throw new CIVLSyntaxException(
"CUDA kernels cannot be anonymous",
definition.getSource());
}
addKernel(kernelName, definition.getEntity());
}
break;
case FUNCTION_DECLARATION :
FunctionDeclarationNode declaration = (FunctionDeclarationNode) child;
if (declaration.hasDeviceFunctionSpecifier() && declaration
.getTypeNode() instanceof FunctionTypeNode) {
String funcName = declaration.getName();
if (funcName == null) {
throw new CIVLSyntaxException(
"__device__ functions cannot be anonymous",
declaration.getSource());
}
addDeviceFunction(funcName, declaration.getEntity());
}
if (declaration.hasGlobalFunctionSpecifier() && declaration
.getTypeNode() instanceof FunctionTypeNode) {
String kernelName = declaration.getName();
if (kernelName == null) {
throw new CIVLSyntaxException(
"CUDA kernels cannot be anonymous",
declaration.getSource());
}
addKernel(kernelName, declaration.getEntity());
}
break;
case VARIABLE_DECLARATION :
VariableDeclarationNode variableDeclaration = (VariableDeclarationNode) child;
if (variableDeclaration.getIdentifier() != null
&& variableDeclaration.getIdentifier().getSource()
.getFirstToken().getSourceFile().getName()
.equals("cuda.h")
&& builtinCUDAVariables.contains(variableDeclaration
.getIdentifier().name())) {
// variableDeclaration.remove();
continue;
}
break;
case STATEMENT :
StatementNode statement = (StatementNode) child;
if (statement.statementKind() == StatementKind.EXPRESSION) {
ExpressionStatementNode expressionStatement = (ExpressionStatementNode) statement;
ExpressionNode expression = expressionStatement
.getExpression();
if (expression
.expressionKind() == ExpressionKind.FUNCTION_CALL) {
FunctionCallNode functionCall = (FunctionCallNode) expression;
if (functionCall
.getNumberOfContextArguments() > 0) {
if (functionCall.getFunction()
.expressionKind() != ExpressionKind.IDENTIFIER_EXPRESSION) {
throw new CIVLSyntaxException(
"CUDA kernel calls must be made with kernel identifier explicitly",
functionCall.getFunction()
.getSource());
}
kernelCalls.add(expressionStatement);
break;
}
}
}
default :
break;
}
scanTree(child);
}
}
/**
* Gets the {@link KernelInfo} associated with kernelName if it exists. If
* no such KernelInfo exists then a fresh one is created and associated to
* kernelName and returned.
*
* @param kernelName
* @return the (potentially fresh) associated KernelInfo
*/
private KernelInfo addKernel(String kernelName, Function kernelEntity) {
KernelInfo kernelEntry = kernelMap.get(kernelName);
if (kernelEntry == null) {
kernelEntry = new KernelInfo(kernelEntity);
kernelMap.put(kernelName, kernelEntry);
} else {
assert kernelEntry.entity == kernelEntity;
}
return kernelEntry;
}
private Function addDeviceFunction(String funcName, Function entity) {
return deviceFunctionMap.put(funcName, entity);
}
private void addEnumTags() {
SequenceNode<EnumeratorDeclarationNode> enumValues = ((EnumerationTypeNode) cudaTagEnumType
.getDefinition()).enumerators();
for (KernelInfo kernel : kernelMap.values()) {
enumValues.addSequenceChild(
nodeDeclEnumerator("addEnumTags", kernel.getTagName()));
}
}
private void translateDeviceFunctions() {
for (Function devFunc : deviceFunctionMap.values()) {
FunctionDefinitionNode definition = devFunc.getDefinition();
for (DeclarationNode decl : devFunc.getDeclarations()) {
FunctionDeclarationNode funcDecl = (FunctionDeclarationNode) decl;
if (funcDecl == definition)
continue;
ASTNode declParent = funcDecl.parent();
int index = funcDecl.childIndex();
funcDecl.remove();
declParent.setChild(index, transformDeviceFunctionDeclaration(
funcDecl, definition.getTypeNode()));
}
ASTNode parentNode = definition.parent();
int index = definition.childIndex();
definition.remove();
parentNode.setChild(index,
transformDeviceFunctionDefinition(definition));
}
}
private DeclarationNode transformDeviceFunctionDeclaration(
FunctionDeclarationNode declNode, FunctionTypeNode funcTypeNode) {
String srcMethod = declNode.getName()
+ ".transformDeviceFunctionDeclaration";
return nodeDeclFunction(srcMethod,
transformCudaFunctionName(declNode.getName()),
funcTypeNode.getReturnType().copy(), generateFormalParameters(
declNode.getName(), funcTypeNode, true));
}
// Each pair is param name followed by named type
private List<Pair<String, String>> contextParams(
boolean includeDeviceParams) {
List<Pair<String, String>> params = new LinkedList<>();
params.addAll(Arrays.asList(new Pair<String, String>("gridDim", "dim3"),
new Pair<String, String>("blockDim", "dim3"),
new Pair<String, String>("_cuda_mem_size", "size_t"),
new Pair<String, String>("_cuda_stream", "cudaStream_t")));
if (includeDeviceParams) {
params.addAll(
Arrays.asList(new Pair<String, String>("blockIdx", "dim3"),
new Pair<String, String>("threadIdx", "dim3"),
new Pair<String, String>("$thread",
"$cuda_thread_data_t")));
}
return params;
}
/**
* Generates a list of {@link VariableDeclarationNode}s which are the full
* list of formal parameters to the kernel. This includes the implicit
* context parameters that all CUDA kernels have.
*
* @param srcMethod
* The name of the method generating these parameters. Used for
* source generation purposes.
* @return
*/
public List<VariableDeclarationNode> generateFormalParameters(
String funcName, FunctionTypeNode funcTypeNode, boolean isDevice) {
String srcMethod = funcName + ".generateFormalParameters";
List<Pair<String, String>> contextParams = contextParams(isDevice);
List<VariableDeclarationNode> formals = new LinkedList<VariableDeclarationNode>();
for (Pair<String, String> contextParam : contextParams) {
formals.add(nodeDeclVar(srcMethod, contextParam.left,
nodeTypeNamed(srcMethod, contextParam.right)));
}
for (VariableDeclarationNode param : funcTypeNode.getParameters()) {
if (param.getTypeNode().kind() == TypeNodeKind.VOID)
continue;
formals.add(param.copy());
}
return formals;
}
private FunctionDefinitionNode transformDeviceFunctionDefinition(
FunctionDefinitionNode defNode) {
String funcName = defNode.getName();
String srcMethod = funcName + ".transformDeviceFunctionDefinition";
CompoundStatementNode body = defNode.getBody();
transformBodyOfCudaFunction(body);
List<BlockItemNode> bodyList = new LinkedList<>();
for (BlockItemNode stmt : body) {
if (stmt != null) {
bodyList.add(stmt.copy());
}
}
return nodeDefnFunction(srcMethod, transformCudaFunctionName(funcName),
defNode.getTypeNode().getReturnType().copy(),
generateFormalParameters(funcName, defNode.getTypeNode(), true),
bodyList);
}
private String transformCudaFunctionName(String funcName) {
return "$cuda_" + funcName;
}
private void transformBodyOfCudaFunction(ASTNode node) {
String srcMethod = "transformBodyOfCudaFunction";
for (ASTNode child : node.children()) {
if (child == null)
continue;
if (child instanceof ExpressionNode) {
ExpressionNode exprNode = (ExpressionNode) child;
if (exprNode instanceof FunctionCallNode) {
FunctionCallNode funcCallNode = (FunctionCallNode) exprNode;
ExpressionNode function = funcCallNode.getFunction();
if (function instanceof IdentifierExpressionNode) {
String functionName = ((IdentifierExpressionNode) function)
.getIdentifier().name();
if (functionName.startsWith("__syncthreads")) {
List<ExpressionNode> args = new LinkedList<>();
for (ExpressionNode arg : funcCallNode
.getArguments()) {
args.add(arg.copy());
}
args.add(nodeExprId(srcMethod, "$thread"));
args.add(nodeExprInt(srcMethod, numSyncthreads));
FunctionCallNode newSyncthreads = nodeExprCall(
srcMethod, "$cuda" + functionName,
args.toArray(new ExpressionNode[0]));
numSyncthreads++;
int index = funcCallNode.childIndex();
funcCallNode.remove();
node.setChild(index, newSyncthreads);
continue;
}
Function devFunction = deviceFunctionMap
.get(functionName);
if (devFunction != null) {
List<ExpressionNode> args = new LinkedList<>();
for (Pair<String, String> contextParam : contextParams(
true)) {
args.add(nodeExprId(srcMethod,
contextParam.left));
}
for (ExpressionNode arg : funcCallNode
.getArguments()) {
args.add(arg.copy());
}
int index = funcCallNode.childIndex();
funcCallNode.remove();
node.setChild(index, nodeExprCall(srcMethod,
transformCudaFunctionName(functionName),
args.toArray(new ExpressionNode[0])));
}
}
}
}
transformBodyOfCudaFunction(child);
}
}
/**
* Performs all transformations related to a kernel entity. This includes:
* <ul>
* <li>Generating a parameter struct definition associated to the kernel;
* <li>Generating a launch function that the host will use to launch
* instances of this kernel;
* <li>Generating a "process" function which the device will $spawn a $proc
* with which represents the CUDA op that executes the kernel
* <li>Transforming the declaration(s) of this kernel to use the new
* (mangled) name;
* <li>Transforming the definition of this kernel to emulate the
* hierarchical thread structure of CUDA and inject data race checks
* </ul>
*/
private void executeKernelTransformations() {
for (KernelInfo kernel : kernelMap.values()) {
FunctionDefinitionNode kernelDefinition = kernel.getDefinition();
boolean firstDecl = true;
for (DeclarationNode decl : kernel.entity.getDeclarations()) {
if (firstDecl) {
@SuppressWarnings("unchecked")
SequenceNode<BlockItemNode> firstDeclParentNode = (SequenceNode<BlockItemNode>) decl
.parent();
firstDeclParentNode.insertChildren(decl.childIndex(),
Arrays.asList(kernel.generateParameterStruct(),
kernel.generateArgRevealFunction(),
kernel.generateKernelLaunchFunction()));
firstDeclParentNode.insertChildren(decl.childIndex() + 1,
Arrays.asList(
kernel.generateKernelProcDefinition()));
firstDecl = false;
}
if (decl == kernelDefinition)
continue;
ASTNode declParent = decl.parent();
int index = decl.childIndex();
decl.remove();
declParent.setChild(index,
kernel.generateTransformedKernelDeclaration());
}
@SuppressWarnings("unchecked")
SequenceNode<BlockItemNode> parentNode = (SequenceNode<BlockItemNode>) kernelDefinition
.parent();
int index = kernelDefinition.childIndex();
kernelDefinition.remove();
parentNode.setChild(index,
kernel.generateTransformedKernelDefinition());
}
}
/**
* Replaces kernel launches (that use the `kernel<<< context args
* >>>(regular args)` syntax) with calls to the corresponding launch
* function, generated earlier.
*/
private void executeKernelCallTransformations() {
String srcMethod = "executeKernelCallTransformations";
for (ExpressionStatementNode kernelCallStatement : kernelCalls) {
ASTNode parent = kernelCallStatement.parent();
FunctionCallNode kernelCallNode = (FunctionCallNode) kernelCallStatement
.getExpression();
String kernelName = ((IdentifierExpressionNode) kernelCallNode
.getFunction()).getIdentifier().name();
KernelInfo kernel = kernelMap.get(kernelName);
List<ExpressionNode> launchArgList = getContextArgList(srcMethod,
kernelCallNode);
for (ExpressionNode argument : kernelCallNode.getArguments()) {
launchArgList.add(argument.copy());
}
parent.setChild(kernelCallStatement.childIndex(),
nodeStmtCall(srcMethod, kernel.getLaunchFunctionName(),
launchArgList.toArray(new ExpressionNode[0])));
}
}
private List<ExpressionNode> getContextArgList(String srcMethod,
FunctionCallNode kernelCall) {
List<ExpressionNode> contextArgs = new ArrayList<>();
for (int i = 0; i < 2; i++) {
ExpressionNode arg = kernelCall.getContextArgument(i);
Type argType = arg.getConvertedType();
if (argType.kind() == TypeKind.QUALIFIED) {
argType = ((QualifiedObjectType) argType).getBaseType();
}
if (argType.kind() == TypeKind.BASIC
&& ((StandardBasicType) argType)
.getBasicTypeKind() == BasicTypeKind.INT) {
contextArgs.add(
nodeExprCall(srcMethod, "$cuda_to_dim3", arg.copy()));
} else {
contextArgs.add(arg.copy());
}
}
int numContextArgs = kernelCall.getNumberOfContextArguments();
contextArgs.add(numContextArgs < 3
? nodeExprInt(srcMethod, 0)
: kernelCall.getContextArgument(2).copy());
contextArgs.add(numContextArgs < 4
? nodeExprNullPointer(srcMethod)
: kernelCall.getContextArgument(3).copy());
return contextArgs;
}
/**
* Alters a body of code by removing any variable declaration with the
* "__shared__" tag and returning a new list of those removed declarations
* without the "__shared__" tag.
*
* @param statements
* a CompountStatementNode that is any section of code
* @return The list of removed variable declarations
*/
private List<VariableDeclarationNode> extractSharedVariableDeclarations(
CompoundStatementNode statements) {
List<VariableDeclarationNode> declarations = new ArrayList<>();
for (BlockItemNode item : statements) {
if (item instanceof VariableDeclarationNode) {
VariableDeclarationNode variableDeclaration = (VariableDeclarationNode) item
.copy();
if (variableDeclaration.hasSharedStorage()) {
statements.removeChild(item.childIndex());
variableDeclaration.setSharedStorage(false);
declarations.add(variableDeclaration);
}
}
}
return declarations;
}
/**
*
* @param sharedVars
* a list of VariableDeclarationNodes
*/
protected void completeSharedExternArrays(
List<VariableDeclarationNode> sharedVars) {
for (VariableDeclarationNode node : sharedVars) {
if (node.hasExternStorage()
&& node.getTypeNode().kind() == TypeNodeKind.ARRAY) {
ArrayTypeNode arrayType = (ArrayTypeNode) node.getTypeNode();
if (arrayType.getExtent() == null) {
arrayType.setExtent(identifierExpression("_cuda_mem_size"));
node.setExternStorage(false);
}
}
}
}
/**
* Transforms every cuda malloc function call using
* {@link Cuda2CIVLWorker#cudaMallocTransform(FunctionCallNode)}. Cuda
* malloc calls are found by recursively searching through the AST for a
* matching function call.
*
* @param root
* the root node of an Abstract Syntax Tree
*/
private void translateCudaMallocCalls(ASTNode root) {
for (ASTNode child : root.children()) {
if (child == null)
continue;
if (child.nodeKind() == NodeKind.EXPRESSION) {
ExpressionNode expression = (ExpressionNode) child;
if (expression
.expressionKind() == ExpressionKind.FUNCTION_CALL) {
FunctionCallNode functionCall = (FunctionCallNode) expression;
if (functionCall.getFunction()
.expressionKind() == ExpressionKind.IDENTIFIER_EXPRESSION) {
IdentifierExpressionNode identifierExpression = (IdentifierExpressionNode) functionCall
.getFunction();
if (identifierExpression.getIdentifier().name()
.equals("cudaMalloc")) {
int index = functionCall.childIndex();
/*
* ASTNode parent = root; while(parent.nodeKind() !=
* NodeKind.STATEMENT) { parent = parent.parent(); }
* StatementNode parentStatement =
* (StatementNode)parent;
*/
root.setChild(index,
cudaMallocTransform(functionCall));
continue;
}
}
}
}
translateCudaMallocCalls(child);
}
}
/**
* Translates "cudaMalloc( (void**) ptrPtr, size)" to "*ptrPtr =
* (type)$malloc($root, size), cudaSuccess" where "type" is the type of
* *ptrPtr
*
* @param cudaMallocCall
* a FunctionCallNode which is a call to cuda malloc
* @return The translated cuda malloc call as an expression node
*/
private ExpressionNode cudaMallocTransform(
FunctionCallNode cudaMallocCall) {
Source source = cudaMallocCall.getSource();
/*
* TypeNode scope = nodeFactory.newScopeTypeNode(source);
* VariableDeclarationNode deviceScopeDeclaration =
* nodeFactory.newVariableDeclarationNode(source,
* identifier("deviceScope"), scope);
* deviceScopeDeclaration.setInitializer(nodeFactory.newFunctionCallNode
* (source, identifierExpression("$cuda_host_request_device_scope"),
* null, null));
*/
FunctionCallNode request_device = nodeFactory.newFunctionCallNode(
source, identifierExpression("$cuda_host_request_device_scope"),
Collections.<ExpressionNode>emptyList(), null);
// find the pointer
ExpressionNode ptrPtr = cudaMallocCall.getArgument(0);
while (ptrPtr instanceof CastNode) {
ptrPtr = ((CastNode) ptrPtr).getArgument();
}
ExpressionNode size = cudaMallocCall.getArgument(1);
// build lhs expression
ExpressionNode assignLhs = nodeFactory.newOperatorNode(
cudaMallocCall.getSource(), Operator.DEREFERENCE,
Arrays.asList(ptrPtr.copy()));
Type lhsType;
if (ptrPtr.getInitialType().kind() == TypeKind.POINTER) {
PointerType ptrType = (PointerType) ptrPtr.getInitialType();
lhsType = ptrType.referencedType();
} else {
lhsType = ptrPtr.getInitialType();
}
// build rhs expression
FunctionCallNode mallocCall = nodeFactory.newFunctionCallNode(source,
identifierExpression(source, "$malloc"),
Arrays.asList(request_device, size.copy()), null);
CastNode mallocCast = nodeFactory.newCastNode(source,
typeNode(source, lhsType), mallocCall);
// create assign node
OperatorNode assignment = nodeFactory.newOperatorNode(
cudaMallocCall.getSource(), Operator.ASSIGN,
Arrays.asList(assignLhs,
nodeFactory.newFunctionCallNode(source,
identifierExpression("$hide"),
Arrays.asList(mallocCast), null)));
/*
* List<BlockItemNode> transformedItems = new
* ArrayList<BlockItemNode>();
* transformedItems.add(deviceScopeDeclaration);
* transformedItems.add(nodeFactory.newExpressionStatementNode(
* assignment));
*/
// CompoundStatementNode newStatements =
// nodeFactory.newCompoundStatementNode(source, transformedItems);
// create comma node
ExpressionNode finalExpression = nodeFactory.newOperatorNode(source,
Operator.COMMA,
Arrays.asList(assignment,
nodeFactory.newEnumerationConstantNode(nodeFactory
.newIdentifierNode(source, "cudaSuccess"))));
// mallocStatement.parent().setChild(mallocStatement.childIndex(),
// newStatements);
// Need to figure out how to put this compound statement node above the
// statement with the cudaMalloc expression
// Maybe I have to copy all the statements like with the atomics
// Return (cuda_C =
// $hide((float*)$malloc($cuda_host_request_device_scope(),
// size)), cudaSuccess)
return finalExpression;
}
private FunctionDefinitionNode createCudaMain() {
String srcMethod = "createCudaMain";
List<BlockItemNode> cudaMainBody = new ArrayList<BlockItemNode>();
createCudaMainGlobalVariables(cudaMainBody);
createDefaultStreamIfNullFunc(cudaMainBody);
createCudaMainWhileLoop(cudaMainBody);
return nodeDefnFunction(srcMethod, CUDA_MAIN, voidType(),
Arrays.asList(), cudaMainBody);
}
private void createCudaMainGlobalVariables(List<BlockItemNode> body) {
String srcMethod = "createCudaMainGlobalVariables";
body.add(nodeDeclVarInit(srcMethod, "$cuda_scope",
nodeTypeScope(srcMethod), nodeExprHere(srcMethod)));
body.add(nodeDeclVarInit(srcMethod, "$cuda_device_comm",
nodeTypeNamed(srcMethod, "$comm"),
nodeExprCall(srcMethod, "$comm_create",
nodeExprId(srcMethod, "$cuda_scope"),
nodeExprId(srcMethod, "$cuda_gcomm"),
nodeExprInt(srcMethod, 1))));
body.add(nodeDeclVar(srcMethod, DEVICE_GLOB_CONTEXT_NAME,
nodeTypeNamed(srcMethod, "$cuda_context")));
body.add(nodeDeclVar(srcMethod, "$cuda_default_stream",
nodeTypeNamed(srcMethod, "cudaStream_t")));
body.add(nodeDeclVarInit(srcMethod, "defaultStreamNode",
nodeTypeNamed(srcMethod, "$cuda_stream_node_t"),
nodeExprCall(srcMethod, "$create_new_stream_node",
nodeExprId(srcMethod, "$cuda_scope"))));
body.add(nodeStmtAssign(CUDA_MAIN,
nodeExprId(srcMethod, "$cuda_default_stream"),
nodeExprArrow(srcMethod,
nodeExprId(srcMethod, "defaultStreamNode"), "stream")));
body.add(nodeStmtAssign(CUDA_MAIN, nodeExprDot(srcMethod,
nodeExprId(srcMethod, DEVICE_GLOB_CONTEXT_NAME), "headNode"),
nodeExprId(srcMethod, "defaultStreamNode")));
body.add(nodeStmtAssign(CUDA_MAIN, nodeExprDot(srcMethod,
nodeExprId(srcMethod, DEVICE_GLOB_CONTEXT_NAME), "numStreams"),
nodeExprInt(srcMethod, 1)));
}
private void createDefaultStreamIfNullFunc(List<BlockItemNode> body) {
String srcMethod = "createDefaultStreamIfNullFunc";
List<BlockItemNode> defaultStreamIfNullBody = new ArrayList<BlockItemNode>();
defaultStreamIfNullBody.add(nodeFactory.newReturnNode(
newSource(srcMethod, CivlcTokenConstant.RETURN),
nodeExprOp(srcMethod, Operator.CONDITIONAL,
nodeExprOp(srcMethod, Operator.EQUALS,
nodeExprId(srcMethod, "stream"),
nodeExprNullPointer(srcMethod)),
nodeExprId(srcMethod, "$cuda_default_stream"),
nodeExprId(srcMethod, "stream"))));
body.add(nodeDefnFunction(srcMethod, "$default_stream_if_null",
nodeTypeNamed(srcMethod, "cudaStream_t"),
Arrays.asList(nodeDeclVar(srcMethod, "stream",
nodeTypeNamed(srcMethod, "cudaStream_t"))),
defaultStreamIfNullBody));
}
private void createCudaMainWhileLoop(List<BlockItemNode> body) {
String srcMethod = "createCudaMainWhileLoop";
List<BlockItemNode> loopBody = new ArrayList<BlockItemNode>();
// TODO: Fix COMM_ANY_TAG issue, for now I will just use the int -2
loopBody.add(nodeDeclVarInit(srcMethod, "request",
nodeTypeNamed(srcMethod, "$message"),
nodeExprCall(srcMethod, "$comm_dequeue",
nodeExprId(srcMethod, "$cuda_device_comm"),
nodeExprId(srcMethod, "$CUDA_PLACE_HOST"),
nodeExprInt(srcMethod, -2))));
loopBody.add(nodeDeclVar(srcMethod, "response",
nodeTypeNamed(srcMethod, "$message")));
loopBody.add(nodeDeclVarInit(srcMethod, "tag", nodeTypeInt(srcMethod),
nodeExprCall(srcMethod, "$message_tag",
nodeExprId(srcMethod, "request"))));
createCudaMainSwitchStatement(loopBody);
loopBody.add(nodeStmtCall(srcMethod, "$comm_enqueue",
nodeExprId(srcMethod, "$cuda_device_comm"),
nodeExprId(srcMethod, "response")));
body.add(nodeFactory.newWhileLoopNode(
newSource(srcMethod, CivlcTokenConstant.WHILE),
booleanConstant(true),
nodeFactory.newCompoundStatementNode(
newSource(srcMethod,
CivlcTokenConstant.COMPOUND_STATEMENT),
loopBody),
null));
}
private void createCudaMainSwitchStatement(List<BlockItemNode> body) {
String srcMethod = "createCudaMainSwitchStatement";
ArrayList<BlockItemNode> switchBody = new ArrayList<BlockItemNode>();
//// $CUDA_TAG_SCOPE_REQUEST ////
List<BlockItemNode> cudaTagScopeRequestBody = new ArrayList<BlockItemNode>();
StatementNode scopeRequestAssignment = nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"),
nodeExprCall(srcMethod, "$message_pack",
nodeExprId(srcMethod, DEVICE_PLACE_NAME),
nodeExprId(srcMethod, HOST_PLACE_NAME),
nodeFactory.newEnumerationConstantNode(
identifier("$CUDA_TAG_SCOPE_REQUEST")),
nodeExprOp(srcMethod, Operator.ADDRESSOF,
nodeExprId(srcMethod, "$cuda_scope")),
nodeExprSizeof(srcMethod, nodeTypeScope(srcMethod))));
cudaTagScopeRequestBody.add(scopeRequestAssignment);
cudaTagScopeRequestBody.add(nodeBreak(srcMethod));
StatementNode cudaTagScopeRequestLabel = nodeSwitchLabeledStmt(
srcMethod, "$CUDA_TAG_SCOPE_REQUEST", cudaTagScopeRequestBody);
switchBody.add(cudaTagScopeRequestLabel);
//// $CUDA_TAG_cudaFree ////
List<BlockItemNode> cudaTagCudaFreeBody = new ArrayList<BlockItemNode>();
StatementNode cudaFreeAssignment = nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"), nodeExprCall(srcMethod,
"$cuda_free", nodeExprId(srcMethod, "request")));
cudaTagCudaFreeBody.add(cudaFreeAssignment);
cudaTagCudaFreeBody.add(nodeBreak(srcMethod));
StatementNode cudaTagCudaFreeLabel = nodeSwitchLabeledStmt(srcMethod,
"$CUDA_TAG_cudaFree", cudaTagCudaFreeBody);
switchBody.add(cudaTagCudaFreeLabel);
//// $CUDA_TAG_cudaMemcpy ////
List<BlockItemNode> cudaTagCudaMemcpyBody = new ArrayList<BlockItemNode>();
StatementNode cudaMemcpyAssignment = nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"),
nodeExprCall(srcMethod, "$cuda_memcpy",
nodeExprId(srcMethod, "$cuda_scope"),
nodeExprId(srcMethod, "$cuda_default_stream"),
nodeExprId(srcMethod, "request"),
booleanConstant(false)));
cudaTagCudaMemcpyBody.add(cudaMemcpyAssignment);
cudaTagCudaMemcpyBody.add(nodeBreak(srcMethod));
StatementNode cudaTagCudaMemcpyLabel = nodeSwitchLabeledStmt(srcMethod,
"$CUDA_TAG_cudaMemcpy", cudaTagCudaMemcpyBody);
switchBody.add(cudaTagCudaMemcpyLabel);
//// $CUDA_TAG_cudaMemcpyAsync ////
List<BlockItemNode> cudaTagCudaMemcpyAsyncBody = new ArrayList<BlockItemNode>();
StatementNode cudaMemcpyAsyncAssignment = nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"),
nodeExprCall(srcMethod, "$cuda_memcpy",
nodeExprId(srcMethod, "$cuda_scope"),
nodeExprId(srcMethod, "$cuda_default_stream"),
nodeExprId(srcMethod, "request"),
booleanConstant(true)));
cudaTagCudaMemcpyAsyncBody.add(cudaMemcpyAsyncAssignment);
cudaTagCudaMemcpyAsyncBody.add(nodeBreak(srcMethod));
StatementNode cudaTagCudaMemcpyAsyncLabel = nodeSwitchLabeledStmt(
srcMethod, "$CUDA_TAG_cudaMemcpyAsync",
cudaTagCudaMemcpyAsyncBody);
switchBody.add(cudaTagCudaMemcpyAsyncLabel);
//// $CUDA_TAG_cudaDeviceSynchronize ////
List<BlockItemNode> cudaTagCudaDeviceSynchronizeBody = new ArrayList<BlockItemNode>();
cudaTagCudaDeviceSynchronizeBody.add(nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"),
nodeExprCall(srcMethod, "$cuda_device_synchronize", nodeExprOp(
srcMethod, Operator.ADDRESSOF,
nodeExprId(srcMethod, DEVICE_GLOB_CONTEXT_NAME)))));
cudaTagCudaDeviceSynchronizeBody.add(nodeBreak(srcMethod));
StatementNode cudaTagCudaDeviceSynchronizeLabel = nodeSwitchLabeledStmt(
srcMethod, "$CUDA_TAG_cudaDeviceSynchronize",
cudaTagCudaDeviceSynchronizeBody);
switchBody.add(cudaTagCudaDeviceSynchronizeLabel);
//// $CUDA_TAG_LAUNCH_Kernel_X ////
for (KernelInfo kernel : kernelMap.values()) {
StatementNode cudaTagLaunchKernelLabel = generateCudaTagLaunchLabel(
kernel.getTagName(), kernel.getKernelProcName());
switchBody.add(cudaTagLaunchKernelLabel);
}
//// $CUDA_TAG_TEARDOWN ////
List<BlockItemNode> cudaTagCudaTeardownBody = new ArrayList<BlockItemNode>();
FunctionCallNode streamDestroyCall = nodeExprCall(srcMethod,
"$destroy_stream_node",
nodeExprArrow(srcMethod,
nodeExprId(srcMethod, "$cuda_default_stream"),
"containingNode"));
VariableDeclarationNode teardownProcDeclaration = nodeDeclVar(srcMethod,
"destructor", nodeTypeNamed(srcMethod, "$proc"));
teardownProcDeclaration.setInitializer(streamDestroyCall);
cudaTagCudaTeardownBody.add(teardownProcDeclaration);
StatementNode waitCall = nodeFactory
.newExpressionStatementNode(nodeExprCall(srcMethod, "$wait",
nodeExprId(srcMethod, "destructor")));
cudaTagCudaTeardownBody.add(waitCall);
StatementNode commDestroyCall = nodeFactory.newExpressionStatementNode(
nodeExprCall(srcMethod, "$comm_destroy",
nodeExprId(srcMethod, "$cuda_device_comm")));
cudaTagCudaTeardownBody.add(commDestroyCall);
cudaTagCudaTeardownBody.add(nodeFactory.newReturnNode(
newSource(srcMethod, CivlcTokenConstant.RETURN), null));
StatementNode cudaTagCudaTeardownLabel = nodeSwitchLabeledStmt(
srcMethod, "$CUDA_TAG_TEARDOWN", cudaTagCudaTeardownBody);
switchBody.add(cudaTagCudaTeardownLabel);
//// default ////
List<BlockItemNode> cudaTagDefaultBody = new ArrayList<BlockItemNode>();
String string = "\"" + "Unknown CUDA request" + "\"";
TokenFactory tokenFactory = astFactory.getTokenFactory();
Formation formation = tokenFactory
.newTransformFormation(transformerName, "stringLiteral");
CivlcToken ctoke = tokenFactory.newCivlcToken(
CivlcTokenConstant.STRING_LITERAL, string, formation,
TokenVocabulary.DUMMY);
StringToken stringToken;
StringLiteral literal = null;
try {
stringToken = tokenFactory.newStringToken(ctoke);
literal = stringToken.getStringLiteral();
} catch (SyntaxException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
StatementNode assertionStatement = nodeFactory
.newExpressionStatementNode(nodeExprCall(srcMethod, "$assert",
booleanConstant(false),
nodeFactory.newStringLiteralNode(
newSource(srcMethod,
CivlcTokenConstant.STRING_LITERAL),
string, literal)));
cudaTagDefaultBody.add(assertionStatement);
StatementNode cudaTagDefaultLabel = nodeSwitchLabeledStmt(srcMethod,
"default", cudaTagDefaultBody);
switchBody.add(cudaTagDefaultLabel);
//// Switch Statement ////
StatementNode switchStatement = nodeFactory.newSwitchNode(
newSource(srcMethod, CivlcTokenConstant.SWITCH),
nodeExprId(srcMethod, "tag"), nodeBlock(srcMethod, switchBody));
body.add(switchStatement);
}
private StatementNode nodeSwitchLabeledStmt(String srcMethod,
String caseName, List<BlockItemNode> body) {
Source labeledStmtSource = newSource(srcMethod,
CivlcTokenConstant.CASE_LABELED_STATEMENT);
SwitchLabelNode labelDecl;
if (caseName.equals("default")) {
labelDecl = nodeFactory
.newDefaultLabelDeclarationNode(labeledStmtSource, null);
} else {
labelDecl = nodeFactory.newCaseLabelDeclarationNode(
labeledStmtSource, nodeFactory.newEnumerationConstantNode(
identifier(caseName)),
null);
}
StatementNode label = nodeFactory.newLabeledStatementNode(
labeledStmtSource, labelDecl, nodeBlock(srcMethod, body));
return label;
}
private StatementNode generateCudaTagLaunchLabel(String caseName,
String procName) {
String srcMethod = "generateCudaTagLaunchLabel";
List<BlockItemNode> cudaTagLaunchBody = new ArrayList<BlockItemNode>();
StatementNode streamEnqueueCall = nodeFactory
.newExpressionStatementNode(nodeExprCall(srcMethod,
"$stream_enqueue", nodeExprId(srcMethod, "$cuda_scope"),
nodeExprId(srcMethod, "$cuda_default_stream"),
nodeExprId(srcMethod, "request"),
nodeExprId(srcMethod, procName)));
cudaTagLaunchBody.add(streamEnqueueCall);
StatementNode responseAssignment = nodeStmtAssign(srcMethod,
nodeExprId(srcMethod, "response"),
nodeExprCall(srcMethod, "$message_pack",
nodeExprId(srcMethod, DEVICE_PLACE_NAME),
nodeExprId(srcMethod, HOST_PLACE_NAME),
nodeExprId(srcMethod, "tag"), nullArgument(srcMethod),
nodeExprInt(srcMethod, 0)));
cudaTagLaunchBody.add(responseAssignment);
cudaTagLaunchBody.add(nodeBreak(srcMethod));
StatementNode cudaTagLaunchLabel = nodeSwitchLabeledStmt(srcMethod,
caseName, cudaTagLaunchBody);
return cudaTagLaunchLabel;
}
/**
* Finds the main function definition node underneath root and calls
* {@link Cuda2CIVLWorker#transformMainFunctionDefinition(FunctionDefinitionNode)}
* on it
*
* @param root
* the root node of an Abstract Syntax Tree
*/
private void translateMainDefinition(SequenceNode<BlockItemNode> root) {
String srcMethod = "translateMainDefinition";
for (ASTNode child : root.children()) {
if (child == null)
continue;
if (child.nodeKind() == NodeKind.FUNCTION_DEFINITION) {
FunctionDefinitionNode definition = (FunctionDefinitionNode) child;
if (definition.getName() != null
&& definition.getName().equals(GEN_MAIN)) {
FunctionDefinitionNode cudaDefinition = createCudaMain();
FunctionTypeNode cudaMainType = cudaDefinition.getTypeNode();
List<VariableDeclarationNode> cudaMainParams = new LinkedList<>();
for (VariableDeclarationNode param : cudaMainType.getParameters()) {
cudaMainParams.add(param.copy());
}
FunctionDeclarationNode cudaMainDecl = nodeDeclFunction(
srcMethod, cudaDefinition.getName(),
cudaMainType.getReturnType().copy(),
cudaMainParams);
FunctionDefinitionNode hostDefinition = definition.copy();
hostDefinition.setIdentifier(
nodeIdent(srcMethod, HOST_MAIN));
root.insertChildren(definition.childIndex(),
Arrays.asList(cudaMainDecl, hostDefinition));
root.addSequenceChild(cudaDefinition);
transformMainFunctionDefinition(definition);
return;
}
}
}
}
/**
* Spawns a host $proc on the "new" main function and a device $proc on the
* "cuda" main function. Waits for these processes and then destroys the
* $gcomm that these two processes communicate with.
*
* @param mainFunction
* the function definition node for the main function
*/
private void transformMainFunctionDefinition(FunctionDefinitionNode mainFunction) {
String srcMethod = "transformMainFunctionDefinition";
String hostProcName = "$host_proc" + newTemporaryVariableName();
String deviceProcName = "$cuda_proc" + newTemporaryVariableName();
List<BlockItemNode> newBody = new LinkedList<BlockItemNode>();
List<ExpressionNode> hostParams = new LinkedList<>();
for (VariableDeclarationNode mainParam : mainFunction.getTypeNode()
.getParameters()) {
hostParams.add(nodeExprId(srcMethod, mainParam.getName()));
}
newBody.add(nodeDeclVarInit(srcMethod, hostProcName,
nodeTypeNamed(srcMethod, "$proc"),
nodeFactory.newSpawnNode(
newSource(srcMethod, CivlcTokenConstant.SPAWN),
nodeExprCall(srcMethod, HOST_MAIN,
hostParams.toArray(new ExpressionNode[] {})))));
newBody.add(nodeDeclVarInit(srcMethod, deviceProcName,
nodeTypeNamed(srcMethod, "$proc"),
nodeFactory.newSpawnNode(
newSource(srcMethod, CivlcTokenConstant.SPAWN),
nodeExprCall(srcMethod, CUDA_MAIN))));
newBody.add(nodeStmtCall(srcMethod, "$wait",
nodeExprId(srcMethod, hostProcName)));
FunctionCallNode messagePackCall = this.nodeExprCall(srcMethod,
"$message_pack", this.nodeExprId(srcMethod, HOST_PLACE_NAME),
this.nodeExprId(srcMethod, DEVICE_PLACE_NAME),
nodeFactory.newEnumerationConstantNode(
this.identifier("$CUDA_TAG_TEARDOWN")),
this.nullArgument(srcMethod), this.nodeExprInt(srcMethod, 0));
FunctionCallNode commEnqueueCall = this.nodeExprCall(srcMethod,
"$comm_enqueue", this.nodeExprId(srcMethod, HOST_COMM_NAME),
messagePackCall);
FunctionCallNode commDestroyCall = this.nodeExprCall(srcMethod,
"$comm_destroy", this.nodeExprId(srcMethod, HOST_COMM_NAME));
newBody.add(nodeFactory.newExpressionStatementNode(commEnqueueCall));
newBody.add(nodeFactory.newExpressionStatementNode(commDestroyCall));
newBody.add(nodeStmtCall(srcMethod, "$wait",
nodeExprId(srcMethod, deviceProcName)));
newBody.add(nodeStmtCall(srcMethod, "$gcomm_destroy",
nodeExprId(srcMethod, "$cuda_gcomm"),
nodeExprNullPointer(srcMethod)));
mainFunction.setBody(nodeFactory.newCompoundStatementNode(
newSource(srcMethod, CivlcTokenConstant.COMPOUND_STATEMENT),
newBody));
}
private CastNode nullArgument(String srcMethod) {
CastNode NULL = nodeExprCast(srcMethod,
nodeTypePointer(srcMethod, voidType()),
nodeExprInt(srcMethod, 0));
return NULL;
}
}