Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 190 additions & 13 deletions src/OpenACCASTConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "acclexer.h"
#include "accparser.h"
#include <antlr4-runtime.h>
#include <algorithm>
#include <memory>

OpenACCDirective *current_directive = NULL;
Expand Down Expand Up @@ -145,7 +146,8 @@ void OpenACCIRConstructor::enterRoutine_directive(

void OpenACCIRConstructor::exitName(accparser::NameContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
((OpenACCRoutineDirective *)current_directive)->setName(expression);
((OpenACCRoutineDirective *)current_directive)
->setName(expression, /*is_string_literal=*/false);
}

void OpenACCIRConstructor::enterSerial_directive(
Expand Down Expand Up @@ -186,8 +188,17 @@ void OpenACCIRConstructor::enterWait_directive(

void OpenACCIRConstructor::enterAsync_clause(
accparser::Async_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_async);
if (ctx->int_expr()) {
std::string expr = trimEnclosingWhiteSpace(ctx->int_expr()->getText());
static_cast<OpenACCAsyncClause *>(current_clause)
->setAsyncExpr(expr);
static_cast<OpenACCAsyncClause *>(current_clause)
->setModifier(ACCC_ASYNC_expr);
} else {
static_cast<OpenACCAsyncClause *>(current_clause)
->setModifier(ACCC_ASYNC_unspecified);
}
}

void OpenACCIRConstructor::exitAsync_clause(
Expand All @@ -211,7 +222,10 @@ void OpenACCIRConstructor::enterBind_clause(
void OpenACCIRConstructor::exitName_or_string(
accparser::Name_or_stringContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause->addLangExpr(expression);
bool is_string =
!expression.empty() && (expression.front() == '"' || expression.front() == '\'');
static_cast<OpenACCBindClause *>(current_clause)
->setBinding(expression, is_string);
}

void OpenACCIRConstructor::exitBind_clause(accparser::Bind_clauseContext *ctx) {
Expand Down Expand Up @@ -361,12 +375,16 @@ void OpenACCIRConstructor::exitDefault_kind(

void OpenACCIRConstructor::enterDefault_async_clause(
accparser::Default_async_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_default_async);
};

void OpenACCIRConstructor::exitDefault_async_clause(
accparser::Default_async_clauseContext *ctx) {
if (ctx->int_expr()) {
std::string expr = trimEnclosingWhiteSpace(ctx->int_expr()->getText());
static_cast<OpenACCDefaultAsyncClause *>(current_clause)
->setAsyncExpr(expr);
}
((OpenACCDefaultAsyncClause *)current_clause)
->mergeClause(current_directive, current_clause);
};
Expand All @@ -386,6 +404,12 @@ void OpenACCIRConstructor::enterDevice_clause(
current_clause = current_directive->addOpenACCClause(ACCC_device);
}

void OpenACCIRConstructor::exitDevice_clause(
accparser::Device_clauseContext *ctx) {
((OpenACCDeviceClause *)current_clause)
->mergeClause(current_directive, current_clause);
}

void OpenACCIRConstructor::enterDevice_num_clause(
accparser::Device_num_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
Expand Down Expand Up @@ -426,7 +450,6 @@ void OpenACCIRConstructor::enterFirstprivate_clause(

void OpenACCIRConstructor::enterGang_clause(
accparser::Gang_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_gang);
};

Expand All @@ -453,13 +476,18 @@ void OpenACCIRConstructor::enterHost_clause(
}

void OpenACCIRConstructor::enterIf_clause(accparser::If_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_if);
}

void OpenACCIRConstructor::exitIf_clause(accparser::If_clauseContext *ctx) {
if (current_clause && current_clause->getKind() == ACCC_if) {
static_cast<OpenACCIfClause *>(current_clause)
->mergeClause(current_directive, current_clause);
}
}

void OpenACCIRConstructor::enterIf_present_clause(
accparser::If_present_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_if_present);
}

Expand Down Expand Up @@ -590,8 +618,22 @@ void OpenACCIRConstructor::enterSelf_list_clause(
current_clause = current_directive->addOpenACCClause(ACCC_self);
}

void OpenACCIRConstructor::exitSelf_list_clause(
accparser::Self_list_clauseContext *ctx) {
((OpenACCSelfClause *)current_clause)
->mergeClause(current_directive, current_clause);
}

void OpenACCIRConstructor::exitCondition(accparser::ConditionContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
if (current_clause && current_clause->getKind() == ACCC_if) {
static_cast<OpenACCIfClause *>(current_clause)->setCondition(expression);
return;
}
if (current_clause && current_clause->getKind() == ACCC_self) {
static_cast<OpenACCSelfClause *>(current_clause)->setCondition(expression);
return;
}
current_clause->addLangExpr(expression);
};

Expand All @@ -605,6 +647,12 @@ void OpenACCIRConstructor::enterTile_clause(
current_clause = current_directive->addOpenACCClause(ACCC_tile);
}

void OpenACCIRConstructor::exitTile_clause(
accparser::Tile_clauseContext *ctx) {
((OpenACCTileClause *)current_clause)
->mergeClause(current_directive, current_clause);
}

void OpenACCIRConstructor::enterUpdate_clause(
accparser::Update_clauseContext *ctx) {
current_clause = current_directive->addOpenACCClause(ACCC_update);
Expand All @@ -628,25 +676,38 @@ void OpenACCIRConstructor::exitVector_clause_modifier(

void OpenACCIRConstructor::exitVector_clause(
accparser::Vector_clauseContext *ctx) {
// If the vector clause had a length expression, capture it
if (ctx->vector_clause_args() && ctx->vector_clause_args()->int_expr()) {
std::string expr =
trimEnclosingWhiteSpace(ctx->vector_clause_args()->int_expr()->getText());
((OpenACCVectorClause *)current_clause)->setLengthExpr(expr);
if (((OpenACCVectorClause *)current_clause)->getModifier() ==
ACCC_VECTOR_unspecified) {
((OpenACCVectorClause *)current_clause)->setModifier(ACCC_VECTOR_expr_only);
}
}
((OpenACCVectorClause *)current_clause)
->mergeClause(current_directive, current_clause);
}

void OpenACCIRConstructor::enterVector_length_clause(
accparser::Vector_length_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_vector_length);
};

void OpenACCIRConstructor::exitVector_length_clause(
accparser::Vector_length_clauseContext *ctx) {
if (ctx->int_expr()) {
std::string expr =
trimEnclosingWhiteSpace(ctx->int_expr()->getText());
((OpenACCVectorLengthClause *)current_clause)->setLengthExpr(expr);
}
((OpenACCVectorLengthClause *)current_clause)
->mergeClause(current_directive, current_clause);
};

void OpenACCIRConstructor::enterVector_no_modifier_clause(
accparser::Vector_no_modifier_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_vector);
}

Expand Down Expand Up @@ -688,9 +749,9 @@ void OpenACCIRConstructor::exitWait_int_expr(
accparser::Wait_int_exprContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
if (current_directive->getKind() == ACCD_wait) {
((OpenACCWaitDirective *)current_directive)->addVar(expression);
((OpenACCWaitDirective *)current_directive)->addAsyncId(expression);
} else {
current_clause->addLangExpr(expression);
static_cast<OpenACCWaitClause *>(current_clause)->addAsyncId(expression);
}
};

Expand All @@ -708,19 +769,26 @@ void OpenACCIRConstructor::enterWorker_clause(

void OpenACCIRConstructor::exitWorker_clause_modifier(
accparser::Worker_clause_modifierContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
((OpenACCWorkerClause *)current_clause)->setModifier(ACCC_WORKER_num);
};

void OpenACCIRConstructor::exitWorker_clause(
accparser::Worker_clauseContext *ctx) {
if (ctx->worker_clause_args() && ctx->worker_clause_args()->int_expr()) {
std::string expr =
trimEnclosingWhiteSpace(ctx->worker_clause_args()->int_expr()->getText());
((OpenACCWorkerClause *)current_clause)->setNumExpr(expr);
if (((OpenACCWorkerClause *)current_clause)->getModifier() ==
ACCC_WORKER_unspecified) {
((OpenACCWorkerClause *)current_clause)->setModifier(ACCC_WORKER_expr_only);
}
}
((OpenACCWorkerClause *)current_clause)
->mergeClause(current_directive, current_clause);
}

void OpenACCIRConstructor::enterWorker_no_modifier_clause(
accparser::Worker_no_modifier_clauseContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause = current_directive->addOpenACCClause(ACCC_worker);
}

Expand All @@ -731,10 +799,58 @@ void OpenACCIRConstructor::enterWrite_clause(

void OpenACCIRConstructor::exitConst_int(accparser::Const_intContext *ctx) {
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
if (current_clause && current_clause->getKind() == ACCC_collapse) {
static_cast<OpenACCCollapseClause *>(current_clause)
->addCountExpr(expression);
return;
}
current_clause->addLangExpr(expression);
};

void OpenACCIRConstructor::exitInt_expr(accparser::Int_exprContext *ctx) {
if (current_clause) {
OpenACCClauseKind kind = current_clause->getKind();
// Vector/worker/vector_length already capture their numeric payloads
if (kind == ACCC_vector || kind == ACCC_worker ||
kind == ACCC_vector_length) {
return;
}
if (kind == ACCC_wait) {
static_cast<OpenACCWaitClause *>(current_clause)
->addAsyncId(trimEnclosingWhiteSpace(ctx->getText()));
return;
}
if (kind == ACCC_async) {
static_cast<OpenACCAsyncClause *>(current_clause)
->setAsyncExpr(trimEnclosingWhiteSpace(ctx->getText()));
static_cast<OpenACCAsyncClause *>(current_clause)
->setModifier(ACCC_ASYNC_expr);
return;
}
if (kind == ACCC_default_async) {
static_cast<OpenACCDefaultAsyncClause *>(current_clause)
->setAsyncExpr(trimEnclosingWhiteSpace(ctx->getText()));
return;
}
if (kind == ACCC_device_num) {
static_cast<OpenACCDeviceNumClause *>(current_clause)
->setDeviceExpr(trimEnclosingWhiteSpace(ctx->getText()));
return;
}
if (kind == ACCC_num_workers) {
static_cast<OpenACCNumWorkersClause *>(current_clause)
->setNumExpr(trimEnclosingWhiteSpace(ctx->getText()));
return;
}
if (kind == ACCC_num_gangs) {
static_cast<OpenACCNumGangsClause *>(current_clause)
->addNum(trimEnclosingWhiteSpace(ctx->getText()));
return;
}
if (kind == ACCC_collapse) {
return;
}
}
std::string expression = trimEnclosingWhiteSpace(ctx->getText());
current_clause->addLangExpr(expression);
};
Expand All @@ -744,6 +860,67 @@ void OpenACCIRConstructor::exitVar(accparser::VarContext *ctx) {
if (current_directive->getKind() == ACCD_cache) {
((OpenACCCacheDirective *)current_directive)->addVar(expression);
} else {
if (current_clause) {
if (current_clause->getKind() == ACCC_device_type) {
static_cast<OpenACCDeviceTypeClause *>(current_clause)
->addDeviceTypeString(expression);
return;
}
if (current_clause->getKind() == ACCC_device) {
static_cast<OpenACCDeviceClause *>(current_clause)
->addDevice(expression);
return;
}
if (current_clause->getKind() == ACCC_copy ||
current_clause->getKind() == ACCC_copyin ||
current_clause->getKind() == ACCC_copyout ||
current_clause->getKind() == ACCC_create ||
current_clause->getKind() == ACCC_no_create ||
current_clause->getKind() == ACCC_present ||
current_clause->getKind() == ACCC_link ||
current_clause->getKind() == ACCC_deviceptr ||
current_clause->getKind() == ACCC_device_resident ||
current_clause->getKind() == ACCC_attach ||
current_clause->getKind() == ACCC_use_device ||
current_clause->getKind() == ACCC_reduction ||
current_clause->getKind() == ACCC_delete ||
current_clause->getKind() == ACCC_detach ||
current_clause->getKind() == ACCC_firstprivate ||
current_clause->getKind() == ACCC_private ||
current_clause->getKind() == ACCC_host ||
current_clause->getKind() == ACCC_self) {
static_cast<OpenACCVarListClause *>(current_clause)->addVar(expression);
return;
}
if (current_clause->getKind() == ACCC_gang) {
auto *gang = static_cast<OpenACCGangClause *>(current_clause);
OpenACCGangArgKind kind = ACCC_GANG_ARG_other;
std::string value = expression;
size_t colon = expression.find(':');
if (colon != std::string::npos) {
std::string key = trimEnclosingWhiteSpace(expression.substr(0, colon));
value = trimEnclosingWhiteSpace(expression.substr(colon + 1));
std::transform(key.begin(), key.end(), key.begin(), ::tolower);
if (key == "num") {
kind = ACCC_GANG_ARG_num;
} else if (key == "dim") {
kind = ACCC_GANG_ARG_dim;
} else if (key == "static") {
kind = ACCC_GANG_ARG_static;
} else {
kind = ACCC_GANG_ARG_other;
value = expression;
}
}
gang->addArg(kind, value);
return;
}
if (current_clause->getKind() == ACCC_tile) {
static_cast<OpenACCTileClause *>(current_clause)
->addTileSize(expression);
return;
}
}
current_clause->addLangExpr(expression);
}
};
Expand Down
7 changes: 7 additions & 0 deletions src/OpenACCASTConstructor.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class OpenACCIRConstructor : public accparserBaseListener {
enterDetach_clause(accparser::Detach_clauseContext * /*ctx*/) override;
virtual void
enterDevice_clause(accparser::Device_clauseContext * /*ctx*/) override;
virtual void
exitDevice_clause(accparser::Device_clauseContext * /*ctx*/) override;
virtual void enterDevice_num_clause(
accparser::Device_num_clauseContext * /*ctx*/) override;
virtual void
Expand All @@ -157,6 +159,7 @@ class OpenACCIRConstructor : public accparserBaseListener {
virtual void
enterHost_clause(accparser::Host_clauseContext * /*ctx*/) override;
virtual void enterIf_clause(accparser::If_clauseContext * /*ctx*/) override;
virtual void exitIf_clause(accparser::If_clauseContext * /*ctx*/) override;
virtual void enterIf_present_clause(
accparser::If_present_clauseContext * /*ctx*/) override;
virtual void enterIndependent_clause(
Expand Down Expand Up @@ -191,10 +194,14 @@ class OpenACCIRConstructor : public accparserBaseListener {
exitSelf_clause(accparser::Self_clauseContext * /*ctx*/) override;
virtual void
enterSelf_list_clause(accparser::Self_list_clauseContext * /*ctx*/) override;
virtual void exitSelf_list_clause(
accparser::Self_list_clauseContext * /*ctx*/) override;
virtual void enterSeq_clause(accparser::Seq_clauseContext * /*ctx*/) override;
virtual void
enterTile_clause(accparser::Tile_clauseContext * /*ctx*/) override;
virtual void
exitTile_clause(accparser::Tile_clauseContext * /*ctx*/) override;
virtual void
enterUpdate_clause(accparser::Update_clauseContext * /*ctx*/) override;
virtual void enterUse_device_clause(
accparser::Use_device_clauseContext * /*ctx*/) override;
Expand Down
Loading