diff --git a/src/OpenACCASTConstructor.cpp b/src/OpenACCASTConstructor.cpp index 6071f1c..bfc7a98 100644 --- a/src/OpenACCASTConstructor.cpp +++ b/src/OpenACCASTConstructor.cpp @@ -2,6 +2,7 @@ #include "acclexer.h" #include "accparser.h" #include +#include #include OpenACCDirective *current_directive = NULL; @@ -145,7 +146,8 @@ void OpenACCIRConstructor::enterRoutine_directive( void OpenACCIRConstructor::exitName(accparser::NameContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); - ((OpenACCRoutineDirective *)current_directive)->setName(expression); + ((OpenACCRoutineDirective *)current_directive) + ->setName(expression, /*is_string_literal=*/false); } void OpenACCIRConstructor::enterSerial_directive( @@ -186,8 +188,17 @@ void OpenACCIRConstructor::enterWait_directive( void OpenACCIRConstructor::enterAsync_clause( accparser::Async_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_async); + if (ctx->int_expr()) { + std::string expr = trimEnclosingWhiteSpace(ctx->int_expr()->getText()); + static_cast(current_clause) + ->setAsyncExpr(expr); + static_cast(current_clause) + ->setModifier(ACCC_ASYNC_expr); + } else { + static_cast(current_clause) + ->setModifier(ACCC_ASYNC_unspecified); + } } void OpenACCIRConstructor::exitAsync_clause( @@ -211,7 +222,10 @@ void OpenACCIRConstructor::enterBind_clause( void OpenACCIRConstructor::exitName_or_string( accparser::Name_or_stringContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); - current_clause->addLangExpr(expression); + bool is_string = + !expression.empty() && (expression.front() == '"' || expression.front() == '\''); + static_cast(current_clause) + ->setBinding(expression, is_string); } void OpenACCIRConstructor::exitBind_clause(accparser::Bind_clauseContext *ctx) { @@ -361,12 +375,16 @@ void OpenACCIRConstructor::exitDefault_kind( void OpenACCIRConstructor::enterDefault_async_clause( accparser::Default_async_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_default_async); }; void OpenACCIRConstructor::exitDefault_async_clause( accparser::Default_async_clauseContext *ctx) { + if (ctx->int_expr()) { + std::string expr = trimEnclosingWhiteSpace(ctx->int_expr()->getText()); + static_cast(current_clause) + ->setAsyncExpr(expr); + } ((OpenACCDefaultAsyncClause *)current_clause) ->mergeClause(current_directive, current_clause); }; @@ -386,6 +404,12 @@ void OpenACCIRConstructor::enterDevice_clause( current_clause = current_directive->addOpenACCClause(ACCC_device); } +void OpenACCIRConstructor::exitDevice_clause( + accparser::Device_clauseContext *ctx) { + ((OpenACCDeviceClause *)current_clause) + ->mergeClause(current_directive, current_clause); +} + void OpenACCIRConstructor::enterDevice_num_clause( accparser::Device_num_clauseContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); @@ -426,7 +450,6 @@ void OpenACCIRConstructor::enterFirstprivate_clause( void OpenACCIRConstructor::enterGang_clause( accparser::Gang_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_gang); }; @@ -453,13 +476,18 @@ void OpenACCIRConstructor::enterHost_clause( } void OpenACCIRConstructor::enterIf_clause(accparser::If_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_if); } +void OpenACCIRConstructor::exitIf_clause(accparser::If_clauseContext *ctx) { + if (current_clause && current_clause->getKind() == ACCC_if) { + static_cast(current_clause) + ->mergeClause(current_directive, current_clause); + } +} + void OpenACCIRConstructor::enterIf_present_clause( accparser::If_present_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_if_present); } @@ -590,8 +618,22 @@ void OpenACCIRConstructor::enterSelf_list_clause( current_clause = current_directive->addOpenACCClause(ACCC_self); } +void OpenACCIRConstructor::exitSelf_list_clause( + accparser::Self_list_clauseContext *ctx) { + ((OpenACCSelfClause *)current_clause) + ->mergeClause(current_directive, current_clause); +} + void OpenACCIRConstructor::exitCondition(accparser::ConditionContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); + if (current_clause && current_clause->getKind() == ACCC_if) { + static_cast(current_clause)->setCondition(expression); + return; + } + if (current_clause && current_clause->getKind() == ACCC_self) { + static_cast(current_clause)->setCondition(expression); + return; + } current_clause->addLangExpr(expression); }; @@ -605,6 +647,12 @@ void OpenACCIRConstructor::enterTile_clause( current_clause = current_directive->addOpenACCClause(ACCC_tile); } +void OpenACCIRConstructor::exitTile_clause( + accparser::Tile_clauseContext *ctx) { + ((OpenACCTileClause *)current_clause) + ->mergeClause(current_directive, current_clause); +} + void OpenACCIRConstructor::enterUpdate_clause( accparser::Update_clauseContext *ctx) { current_clause = current_directive->addOpenACCClause(ACCC_update); @@ -628,25 +676,38 @@ void OpenACCIRConstructor::exitVector_clause_modifier( void OpenACCIRConstructor::exitVector_clause( accparser::Vector_clauseContext *ctx) { + // If the vector clause had a length expression, capture it + if (ctx->vector_clause_args() && ctx->vector_clause_args()->int_expr()) { + std::string expr = + trimEnclosingWhiteSpace(ctx->vector_clause_args()->int_expr()->getText()); + ((OpenACCVectorClause *)current_clause)->setLengthExpr(expr); + if (((OpenACCVectorClause *)current_clause)->getModifier() == + ACCC_VECTOR_unspecified) { + ((OpenACCVectorClause *)current_clause)->setModifier(ACCC_VECTOR_expr_only); + } + } ((OpenACCVectorClause *)current_clause) ->mergeClause(current_directive, current_clause); } void OpenACCIRConstructor::enterVector_length_clause( accparser::Vector_length_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_vector_length); }; void OpenACCIRConstructor::exitVector_length_clause( accparser::Vector_length_clauseContext *ctx) { + if (ctx->int_expr()) { + std::string expr = + trimEnclosingWhiteSpace(ctx->int_expr()->getText()); + ((OpenACCVectorLengthClause *)current_clause)->setLengthExpr(expr); + } ((OpenACCVectorLengthClause *)current_clause) ->mergeClause(current_directive, current_clause); }; void OpenACCIRConstructor::enterVector_no_modifier_clause( accparser::Vector_no_modifier_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_vector); } @@ -688,9 +749,9 @@ void OpenACCIRConstructor::exitWait_int_expr( accparser::Wait_int_exprContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); if (current_directive->getKind() == ACCD_wait) { - ((OpenACCWaitDirective *)current_directive)->addVar(expression); + ((OpenACCWaitDirective *)current_directive)->addAsyncId(expression); } else { - current_clause->addLangExpr(expression); + static_cast(current_clause)->addAsyncId(expression); } }; @@ -708,19 +769,26 @@ void OpenACCIRConstructor::enterWorker_clause( void OpenACCIRConstructor::exitWorker_clause_modifier( accparser::Worker_clause_modifierContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); ((OpenACCWorkerClause *)current_clause)->setModifier(ACCC_WORKER_num); }; void OpenACCIRConstructor::exitWorker_clause( accparser::Worker_clauseContext *ctx) { + if (ctx->worker_clause_args() && ctx->worker_clause_args()->int_expr()) { + std::string expr = + trimEnclosingWhiteSpace(ctx->worker_clause_args()->int_expr()->getText()); + ((OpenACCWorkerClause *)current_clause)->setNumExpr(expr); + if (((OpenACCWorkerClause *)current_clause)->getModifier() == + ACCC_WORKER_unspecified) { + ((OpenACCWorkerClause *)current_clause)->setModifier(ACCC_WORKER_expr_only); + } + } ((OpenACCWorkerClause *)current_clause) ->mergeClause(current_directive, current_clause); } void OpenACCIRConstructor::enterWorker_no_modifier_clause( accparser::Worker_no_modifier_clauseContext *ctx) { - std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause = current_directive->addOpenACCClause(ACCC_worker); } @@ -731,10 +799,58 @@ void OpenACCIRConstructor::enterWrite_clause( void OpenACCIRConstructor::exitConst_int(accparser::Const_intContext *ctx) { std::string expression = trimEnclosingWhiteSpace(ctx->getText()); + if (current_clause && current_clause->getKind() == ACCC_collapse) { + static_cast(current_clause) + ->addCountExpr(expression); + return; + } current_clause->addLangExpr(expression); }; void OpenACCIRConstructor::exitInt_expr(accparser::Int_exprContext *ctx) { + if (current_clause) { + OpenACCClauseKind kind = current_clause->getKind(); + // Vector/worker/vector_length already capture their numeric payloads + if (kind == ACCC_vector || kind == ACCC_worker || + kind == ACCC_vector_length) { + return; + } + if (kind == ACCC_wait) { + static_cast(current_clause) + ->addAsyncId(trimEnclosingWhiteSpace(ctx->getText())); + return; + } + if (kind == ACCC_async) { + static_cast(current_clause) + ->setAsyncExpr(trimEnclosingWhiteSpace(ctx->getText())); + static_cast(current_clause) + ->setModifier(ACCC_ASYNC_expr); + return; + } + if (kind == ACCC_default_async) { + static_cast(current_clause) + ->setAsyncExpr(trimEnclosingWhiteSpace(ctx->getText())); + return; + } + if (kind == ACCC_device_num) { + static_cast(current_clause) + ->setDeviceExpr(trimEnclosingWhiteSpace(ctx->getText())); + return; + } + if (kind == ACCC_num_workers) { + static_cast(current_clause) + ->setNumExpr(trimEnclosingWhiteSpace(ctx->getText())); + return; + } + if (kind == ACCC_num_gangs) { + static_cast(current_clause) + ->addNum(trimEnclosingWhiteSpace(ctx->getText())); + return; + } + if (kind == ACCC_collapse) { + return; + } + } std::string expression = trimEnclosingWhiteSpace(ctx->getText()); current_clause->addLangExpr(expression); }; @@ -744,6 +860,67 @@ void OpenACCIRConstructor::exitVar(accparser::VarContext *ctx) { if (current_directive->getKind() == ACCD_cache) { ((OpenACCCacheDirective *)current_directive)->addVar(expression); } else { + if (current_clause) { + if (current_clause->getKind() == ACCC_device_type) { + static_cast(current_clause) + ->addDeviceTypeString(expression); + return; + } + if (current_clause->getKind() == ACCC_device) { + static_cast(current_clause) + ->addDevice(expression); + return; + } + if (current_clause->getKind() == ACCC_copy || + current_clause->getKind() == ACCC_copyin || + current_clause->getKind() == ACCC_copyout || + current_clause->getKind() == ACCC_create || + current_clause->getKind() == ACCC_no_create || + current_clause->getKind() == ACCC_present || + current_clause->getKind() == ACCC_link || + current_clause->getKind() == ACCC_deviceptr || + current_clause->getKind() == ACCC_device_resident || + current_clause->getKind() == ACCC_attach || + current_clause->getKind() == ACCC_use_device || + current_clause->getKind() == ACCC_reduction || + current_clause->getKind() == ACCC_delete || + current_clause->getKind() == ACCC_detach || + current_clause->getKind() == ACCC_firstprivate || + current_clause->getKind() == ACCC_private || + current_clause->getKind() == ACCC_host || + current_clause->getKind() == ACCC_self) { + static_cast(current_clause)->addVar(expression); + return; + } + if (current_clause->getKind() == ACCC_gang) { + auto *gang = static_cast(current_clause); + OpenACCGangArgKind kind = ACCC_GANG_ARG_other; + std::string value = expression; + size_t colon = expression.find(':'); + if (colon != std::string::npos) { + std::string key = trimEnclosingWhiteSpace(expression.substr(0, colon)); + value = trimEnclosingWhiteSpace(expression.substr(colon + 1)); + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + if (key == "num") { + kind = ACCC_GANG_ARG_num; + } else if (key == "dim") { + kind = ACCC_GANG_ARG_dim; + } else if (key == "static") { + kind = ACCC_GANG_ARG_static; + } else { + kind = ACCC_GANG_ARG_other; + value = expression; + } + } + gang->addArg(kind, value); + return; + } + if (current_clause->getKind() == ACCC_tile) { + static_cast(current_clause) + ->addTileSize(expression); + return; + } + } current_clause->addLangExpr(expression); } }; diff --git a/src/OpenACCASTConstructor.h b/src/OpenACCASTConstructor.h index 2af4180..c5d48b0 100644 --- a/src/OpenACCASTConstructor.h +++ b/src/OpenACCASTConstructor.h @@ -132,6 +132,8 @@ class OpenACCIRConstructor : public accparserBaseListener { enterDetach_clause(accparser::Detach_clauseContext * /*ctx*/) override; virtual void enterDevice_clause(accparser::Device_clauseContext * /*ctx*/) override; + virtual void + exitDevice_clause(accparser::Device_clauseContext * /*ctx*/) override; virtual void enterDevice_num_clause( accparser::Device_num_clauseContext * /*ctx*/) override; virtual void @@ -157,6 +159,7 @@ class OpenACCIRConstructor : public accparserBaseListener { virtual void enterHost_clause(accparser::Host_clauseContext * /*ctx*/) override; virtual void enterIf_clause(accparser::If_clauseContext * /*ctx*/) override; + virtual void exitIf_clause(accparser::If_clauseContext * /*ctx*/) override; virtual void enterIf_present_clause( accparser::If_present_clauseContext * /*ctx*/) override; virtual void enterIndependent_clause( @@ -191,10 +194,14 @@ class OpenACCIRConstructor : public accparserBaseListener { exitSelf_clause(accparser::Self_clauseContext * /*ctx*/) override; virtual void enterSelf_list_clause(accparser::Self_list_clauseContext * /*ctx*/) override; + virtual void exitSelf_list_clause( + accparser::Self_list_clauseContext * /*ctx*/) override; virtual void enterSeq_clause(accparser::Seq_clauseContext * /*ctx*/) override; virtual void enterTile_clause(accparser::Tile_clauseContext * /*ctx*/) override; virtual void + exitTile_clause(accparser::Tile_clauseContext * /*ctx*/) override; + virtual void enterUpdate_clause(accparser::Update_clauseContext * /*ctx*/) override; virtual void enterUse_device_clause( accparser::Use_device_clauseContext * /*ctx*/) override; diff --git a/src/OpenACCIR.cpp b/src/OpenACCIR.cpp index 27c8a83..3d9c1f2 100644 --- a/src/OpenACCIR.cpp +++ b/src/OpenACCIR.cpp @@ -1,4 +1,5 @@ #include "OpenACCIR.h" +#include #include // Initialize static flag - default to true for backward compatibility @@ -44,32 +45,15 @@ OpenACCClause *OpenACCDirective::addOpenACCClause(int k, ...) { OpenACCClause *new_clause = NULL; switch (kind) { - case ACCC_attach: case ACCC_auto: case ACCC_capture: - case ACCC_copy: - case ACCC_delete: - case ACCC_detach: - case ACCC_device: - case ACCC_device_resident: - case ACCC_device_type: - case ACCC_deviceptr: case ACCC_finalize: - case ACCC_firstprivate: - case ACCC_host: - case ACCC_if: case ACCC_if_present: case ACCC_independent: - case ACCC_link: case ACCC_nohost: - case ACCC_no_create: - case ACCC_present: - case ACCC_private: case ACCC_read: case ACCC_seq: - case ACCC_tile: case ACCC_update: - case ACCC_use_device: case ACCC_write: { if (current_clauses->size() == 0) { new_clause = new OpenACCClause(kind); @@ -92,37 +76,23 @@ OpenACCClause *OpenACCDirective::addOpenACCClause(int k, ...) { } break; } - case ACCC_gang: { - if (this->getKind() == ACCD_routine) { + case ACCC_device_type: { + if (current_clauses->size() == 0 || !enable_clause_merging) { + new_clause = OpenACCDeviceTypeClause::addClause(this); if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(kind); current_clauses = new std::vector(); - current_clauses->push_back(new_clause); clauses[kind] = current_clauses; - } else { - /* we can have multiple clause and we merge them together now, thus we - * return the object that is already created */ - new_clause = current_clauses->at(0); } - break; + current_clauses->push_back(new_clause); } else { - if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(kind); - current_clauses = new std::vector(); - current_clauses->push_back(new_clause); - clauses[kind] = current_clauses; - } else { - new_clause = new OpenACCClause(kind); - current_clauses->push_back(new_clause); - } - break; + new_clause = current_clauses->at(0); } + break; } - - case ACCC_self: { - if (this->getKind() == ACCD_update) { + case ACCC_gang: { + if (this->getKind() == ACCD_routine) { if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(kind); + new_clause = new OpenACCGangClause(); current_clauses = new std::vector(); current_clauses->push_back(new_clause); clauses[kind] = current_clauses; @@ -134,34 +104,107 @@ OpenACCClause *OpenACCDirective::addOpenACCClause(int k, ...) { break; } else { if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(kind); + new_clause = new OpenACCGangClause(); current_clauses = new std::vector(); current_clauses->push_back(new_clause); clauses[kind] = current_clauses; } else { - new_clause = new OpenACCClause(kind); + new_clause = new OpenACCGangClause(); current_clauses->push_back(new_clause); } break; } } + + case ACCC_self: { + new_clause = OpenACCSelfClause::addClause(this); + break; + } case ACCC_async: - case ACCC_bind: - case ACCC_collapse: - case ACCC_default_async: - case ACCC_device_num: - case ACCC_num_gangs: - case ACCC_num_workers: - case ACCC_vector_length: { - if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(kind); - current_clauses = new std::vector(); - current_clauses->push_back(new_clause); - clauses[kind] = current_clauses; - } else { - new_clause = new OpenACCClause(kind); - current_clauses->push_back(new_clause); - } + new_clause = OpenACCAsyncClause::addClause(this); + break; + case ACCC_delete: { + new_clause = OpenACCDeleteClause::addClause(this); + break; + } + case ACCC_detach: { + new_clause = OpenACCDetachClause::addClause(this); + break; + } + case ACCC_collapse: { + new_clause = OpenACCCollapseClause::addClause(this); + break; + } + case ACCC_attach: { + new_clause = OpenACCAttachClause::addClause(this); + break; + } + case ACCC_copy: { + new_clause = OpenACCCopyClause::addClause(this); + break; + } + case ACCC_device: { + new_clause = OpenACCDeviceClause::addClause(this); + break; + } + case ACCC_device_resident: { + new_clause = OpenACCDeviceResidentClause::addClause(this); + break; + } + case ACCC_deviceptr: { + new_clause = OpenACCDeviceptrClause::addClause(this); + break; + } + case ACCC_default_async: { + new_clause = OpenACCDefaultAsyncClause::addClause(this); + break; + } + case ACCC_if: { + new_clause = OpenACCIfClause::addClause(this); + break; + } + case ACCC_firstprivate: { + new_clause = OpenACCFirstprivateClause::addClause(this); + break; + } + case ACCC_num_workers: { + new_clause = OpenACCNumWorkersClause::addClause(this); + break; + } + case ACCC_no_create: { + new_clause = OpenACCNoCreateClause::addClause(this); + break; + } + case ACCC_device_num: { + new_clause = OpenACCDeviceNumClause::addClause(this); + break; + } + case ACCC_present: { + new_clause = OpenACCPresentClause::addClause(this); + break; + } + case ACCC_private: { + new_clause = OpenACCPrivateClause::addClause(this); + break; + } + case ACCC_num_gangs: { + new_clause = OpenACCNumGangsClause::addClause(this); + break; + } + case ACCC_link: { + new_clause = OpenACCLinkClause::addClause(this); + break; + } + case ACCC_host: { + new_clause = OpenACCHostClause::addClause(this); + break; + } + case ACCC_tile: { + new_clause = OpenACCTileClause::addClause(this); + break; + } + case ACCC_bind: { + new_clause = OpenACCBindClause::addClause(this); break; } case ACCC_copyin: { @@ -184,10 +227,18 @@ OpenACCClause *OpenACCDirective::addOpenACCClause(int k, ...) { new_clause = OpenACCReductionClause::addClause(this); break; } + case ACCC_use_device: { + new_clause = OpenACCUseDeviceClause::addClause(this); + break; + } case ACCC_vector: { new_clause = OpenACCVectorClause::addClause(this); break; } + case ACCC_vector_length: { + new_clause = OpenACCVectorLengthClause::addClause(this); + break; + } case ACCC_wait: { new_clause = OpenACCWaitClause::addClause(this); break; @@ -221,32 +272,43 @@ void OpenACCAsyncClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_async); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } - break; - } else if (((((OpenACCClause *)(current_clause))->getExpressions()) - ->size() == 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() == 0)) { + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (incoming->getModifier() == existing->getModifier() && + incoming->getAsyncExpr() == existing->getAsyncExpr()) { current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCAsyncClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_async); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCAsyncClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_async] = current_clauses; + } else { + new_clause = new OpenACCAsyncClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + void OpenACCBindClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -257,26 +319,43 @@ void OpenACCBindClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_bind); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getBinding() == incoming->getBinding() && + existing->isStringLiteral() == incoming->isStringLiteral()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCBindClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_bind); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCBindClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_bind] = current_clauses; + } else { + new_clause = new OpenACCBindClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + void OpenACCCollapseClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -284,35 +363,183 @@ void OpenACCCollapseClause::mergeClause(OpenACCDirective *directive, return; } - std::vector *current_clauses = - directive->getClauses(ACCC_collapse); + auto *current_clauses = directive->getClauses(ACCC_collapse); + if (current_clauses->size() < 2) { + return; + } - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + + bool merged = false; + if (existing->getCounts().empty() && incoming->getCounts().empty()) { + merged = true; + } else if (!incoming->getCounts().empty()) { + merged = true; + for (const auto &count : incoming->getCounts()) { + bool found = false; + for (const auto &prev : existing->getCounts()) { + if (prev == count) { + found = true; + break; + } + } + if (!found) { + existing->addCountExpr(count); + } } - break; - } else if (((((OpenACCClause *)(current_clause))->getExpressions()) - ->size() == 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() == 0)) { + } + + if (merged) { current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +static void mergeVarList(OpenACCVarListClause *existing, + OpenACCVarListClause *incoming) { + for (const auto &var : incoming->getVars()) { + existing->addVar(var); + } +} + +void OpenACCAttachClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_attach); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCAttachClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_attach); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCAttachClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_attach] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCAttachClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCDeleteClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_delete); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCDeleteClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_delete); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDeleteClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_delete] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCDeleteClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCDetachClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_detach); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCDetachClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_detach); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDetachClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_detach] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCDetachClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCCopyClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + (void)directive; + (void)current_clause; +} + +OpenACCClause *OpenACCCopyClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_copy); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCCopyClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_copy] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCCopyClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + OpenACCClause *OpenACCCopyinClause::addClause(OpenACCDirective *directive) { std::map *> *all_clauses = @@ -349,25 +576,10 @@ void OpenACCCopyinClause::mergeClause(OpenACCDirective *directive, ((OpenACCCopyinClause *)current_clause)->getModifier() && ((OpenACCClause *)(*it))->getOriginalKeyword() == ((OpenACCClause *)current_clause)->getOriginalKeyword()) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; - } - } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); + auto *existing = static_cast(*it); + auto *incoming = static_cast(current_clause); + for (const auto &var : incoming->getVars()) { + existing->addVar(var); } current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); @@ -412,25 +624,10 @@ void OpenACCCopyoutClause::mergeClause(OpenACCDirective *directive, ((OpenACCCopyoutClause *)current_clause)->getModifier() && ((OpenACCClause *)(*it))->getOriginalKeyword() == ((OpenACCClause *)current_clause)->getOriginalKeyword()) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; - } - } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); + auto *existing = static_cast(*it); + auto *incoming = static_cast(current_clause); + for (const auto &var : incoming->getVars()) { + existing->addVar(var); } current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); @@ -475,25 +672,10 @@ void OpenACCCreateClause::mergeClause(OpenACCDirective *directive, ((OpenACCCreateClause *)current_clause)->getModifier() && ((OpenACCClause *)(*it))->getOriginalKeyword() == ((OpenACCClause *)current_clause)->getOriginalKeyword()) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; - } - } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); + auto *existing = static_cast(*it); + auto *incoming = static_cast(current_clause); + for (const auto &var : incoming->getVars()) { + existing->addVar(var); } current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); @@ -502,57 +684,507 @@ void OpenACCCreateClause::mergeClause(OpenACCDirective *directive, } }; -OpenACCClause *OpenACCDefaultClause::addClause(OpenACCDirective *directive) { +OpenACCClause *OpenACCNoCreateClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_no_create); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCNoCreateClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_no_create] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCNoCreateClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} - std::map *> *all_clauses = - directive->getAllClauses(); - std::vector *current_clauses = - directive->getClauses(ACCC_default); - OpenACCClause *new_clause = NULL; +OpenACCClause *OpenACCPresentClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_present); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCPresentClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_present] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCPresentClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} +OpenACCClause *OpenACCLinkClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_link); + OpenACCClause *new_clause = nullptr; if (current_clauses->size() == 0) { - new_clause = new OpenACCDefaultClause(); + new_clause = new OpenACCLinkClause(); current_clauses = new std::vector(); current_clauses->push_back(new_clause); - (*all_clauses)[ACCC_default] = current_clauses; - } else { /* could be an error since default clause may only appear once */ - std::cerr << "Cannot have two default clause for the directive " - << directive->getKind() << ", ignored\n"; - }; + (*all_clauses)[ACCC_link] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCLinkClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} +OpenACCClause * +OpenACCDeviceResidentClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_device_resident); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDeviceResidentClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_device_resident] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCDeviceResidentClause(); + current_clauses->push_back(new_clause); + } return new_clause; -}; +} -void OpenACCDefaultAsyncClause::mergeClause(OpenACCDirective *directive, +void OpenACCFirstprivateClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { - // Respect the global clause merging flag if (!OpenACCDirective::getClauseMerging()) { return; } - - std::vector *current_clauses = - directive->getClauses(ACCC_default_async); - - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + auto *current_clauses = directive->getClauses(ACCC_firstprivate); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = static_cast( + current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause * +OpenACCFirstprivateClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_firstprivate); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCFirstprivateClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_firstprivate] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCFirstprivateClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCHostClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_host); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCHostClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_host); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCHostClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_host] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCHostClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCPrivateClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_private); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCPrivateClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_private); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCPrivateClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_private] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCPrivateClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +OpenACCClause *OpenACCSelfClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_self); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + current_clauses = new std::vector(); + (*all_clauses)[ACCC_self] = current_clauses; + } + new_clause = new OpenACCSelfClause(); + current_clauses->push_back(new_clause); + return new_clause; +} + +OpenACCClause *OpenACCDeviceptrClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_deviceptr); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDeviceptrClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_deviceptr] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCDeviceptrClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +OpenACCClause *OpenACCUseDeviceClause::addClause(OpenACCDirective *directive) { + auto *all_clauses = directive->getAllClauses(); + auto *current_clauses = directive->getClauses(ACCC_use_device); + OpenACCClause *new_clause = nullptr; + if (current_clauses->size() == 0) { + new_clause = new OpenACCUseDeviceClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_use_device] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCUseDeviceClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCNoCreateClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_no_create); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +void OpenACCPresentClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_present); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +void OpenACCLinkClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_link); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +void OpenACCDeviceResidentClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_device_resident); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = static_cast( + current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +void OpenACCDeviceptrClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_deviceptr); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +void OpenACCUseDeviceClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + if (!OpenACCDirective::getClauseMerging()) { + return; + } + auto *current_clauses = directive->getClauses(ACCC_use_device); + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + auto *existing = + static_cast(current_clauses->front()); + mergeVarList(existing, incoming); + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCDefaultClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_default); + OpenACCClause *new_clause = NULL; + + if (current_clauses->size() == 0) { + new_clause = new OpenACCDefaultClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_default] = current_clauses; + } else { /* could be an error since default clause may only appear once */ + std::cerr << "Cannot have two default clause for the directive " + << directive->getKind() << ", ignored\n"; + }; + + return new_clause; +}; + +void OpenACCDefaultAsyncClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + // Respect the global clause merging flag + if (!OpenACCDirective::getClauseMerging()) { + return; + } + + std::vector *current_clauses = + directive->getClauses(ACCC_default_async); + + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (incoming->getAsyncExpr() == existing->getAsyncExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCDefaultAsyncClause::addClause( + OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_default_async); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDefaultAsyncClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_default_async] = current_clauses; + } else { + new_clause = new OpenACCDefaultAsyncClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCDeviceClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + // Respect the global clause merging flag + if (!OpenACCDirective::getClauseMerging()) { + return; + } + + auto *current_clauses = directive->getClauses(ACCC_device); + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + auto *existing = static_cast(current_clauses->front()); + + for (const auto &dev : incoming->getDevices()) { + if (std::find(existing->getDevices().begin(), + existing->getDevices().end(), + dev) == existing->getDevices().end()) { + existing->addDevice(dev); + } + } + + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCDeviceClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_device); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDeviceClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_device] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCDeviceClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCIfClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + // Respect the global clause merging flag + if (!OpenACCDirective::getClauseMerging()) { + return; + } + + auto *current_clauses = directive->getClauses(ACCC_if); + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + auto *existing = static_cast(current_clauses->front()); + + if (!incoming->getCondition().empty() && existing->getCondition().empty()) { + existing->setCondition(incoming->getCondition()); + } + + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; +} + +OpenACCClause *OpenACCIfClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_if); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCIfClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_if] = current_clauses; + } else if (OpenACCDirective::getClauseMerging()) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCIfClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + void OpenACCDeviceNumClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -565,24 +1197,120 @@ void OpenACCDeviceNumClause::mergeClause(OpenACCDirective *directive, for (std::vector::iterator it = current_clauses->begin(); it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + auto *existing = static_cast(*it); + auto *incoming = static_cast(current_clause); + if (existing->getDeviceExpr() == incoming->getDeviceExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCDeviceNumClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_device_num); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCDeviceNumClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_device_num] = current_clauses; + } else { + new_clause = new OpenACCDeviceNumClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +OpenACCClause *OpenACCTileClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_tile); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCTileClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_tile] = current_clauses; + } else { + new_clause = new OpenACCTileClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCTileClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *current_clause) { + // Respect the global clause merging flag + if (!OpenACCDirective::getClauseMerging()) { + return; + } + + auto *current_clauses = directive->getClauses(ACCC_tile); + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + for (const auto &size : incoming->getTileSizes()) { + existing->addTileSize(size); + } + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; + } +} + +OpenACCClause *OpenACCCollapseClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_collapse); + OpenACCClause *new_clause = NULL; + if (current_clauses->size() == 0) { + new_clause = new OpenACCCollapseClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_collapse] = current_clauses; + } else { + new_clause = new OpenACCCollapseClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + +OpenACCClause *OpenACCGangClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_gang); + OpenACCGangClause *new_clause = nullptr; + if (directive->getKind() == ACCD_routine && current_clauses->size() != 0) { + new_clause = static_cast(current_clauses->at(0)); + } else { + new_clause = new OpenACCGangClause(); + if (current_clauses->size() == 0) { + current_clauses = new std::vector(); + (*all_clauses)[ACCC_gang] = current_clauses; + } + current_clauses->push_back(new_clause); + } + return new_clause; +} + void OpenACCGangClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -593,39 +1321,42 @@ void OpenACCGangClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_gang); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() == - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() == 0)) { + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + + bool incoming_empty = incoming->getArgs().empty(); + bool existing_empty = existing->getArgs().empty(); + + // Merge empty clauses + if (incoming_empty && existing_empty) { current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; - } else if (((((OpenACCClause *)(current_clause))->getExpressions()) - ->size() != 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; + } + + if (!incoming_empty && !existing_empty) { + for (const auto &arg : incoming->getArgs()) { + bool found = false; + for (const auto &prev : existing->getArgs()) { + if (prev.kind == arg.kind && prev.value == arg.value) { + found = true; + break; } } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); + if (!found) { + existing->addArg(arg.kind, arg.value); + } } current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } @@ -641,26 +1372,43 @@ void OpenACCNumGangsClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_num_gangs); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getNums() == incoming->getNums()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCNumGangsClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_num_gangs); + OpenACCClause *new_clause = NULL; + + if (current_clauses->size() == 0) { + new_clause = new OpenACCNumGangsClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_num_gangs] = current_clauses; + } else { + new_clause = new OpenACCNumGangsClause(); + current_clauses->push_back(new_clause); + } + return new_clause; +} + void OpenACCNumWorkersClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -671,26 +1419,44 @@ void OpenACCNumWorkersClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_num_workers); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getNumExpr() == incoming->getNumExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } }; +OpenACCClause *OpenACCNumWorkersClause::addClause(OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_num_workers); + OpenACCClause *new_clause = NULL; + if (directive->getKind() == ACCD_routine && current_clauses->size() != 0) { + new_clause = current_clauses->at(0); + } else { + new_clause = new OpenACCNumWorkersClause(); + if (current_clauses->size() == 0) { + current_clauses = new std::vector(); + (*all_clauses)[ACCC_num_workers] = current_clauses; + } + current_clauses->push_back(new_clause); + } + + return new_clause; +} + OpenACCClause *OpenACCReductionClause::addClause(OpenACCDirective *directive) { std::map *> *all_clauses = @@ -725,26 +1491,9 @@ void OpenACCReductionClause::mergeClause(OpenACCDirective *directive, it != current_clauses->end() - 1; it++) { if (((OpenACCReductionClause *)(*it))->getOperator() == ((OpenACCReductionClause *)current_clause)->getOperator()) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; - } - } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); - } + auto *existing = static_cast(*it); + auto *incoming = static_cast(current_clause); + mergeVarList(existing, incoming); current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); break; @@ -754,35 +1503,51 @@ void OpenACCReductionClause::mergeClause(OpenACCDirective *directive, void OpenACCSelfClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { - // Respect the global clause merging flag if (!OpenACCDirective::getClauseMerging()) { return; } - std::vector *current_clauses = - directive->getClauses(ACCC_self); + auto *current_clauses = directive->getClauses(ACCC_self); + if (current_clauses->size() < 2) { + return; + } - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { + auto *incoming = static_cast(current_clause); + const bool incoming_has_condition = !incoming->getCondition().empty(); + + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + const bool existing_has_condition = !existing->getCondition().empty(); + + if (incoming_has_condition && existing_has_condition) { + if (incoming->getCondition() == existing->getCondition()) { current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; + } + continue; + } + + // Merge self clauses that carry variable lists or are empty. + if (!incoming_has_condition && !existing_has_condition) { + const auto &incoming_vars = incoming->getVars(); + const auto &existing_vars = existing->getVars(); + const bool both_empty = incoming_vars.empty() && existing_vars.empty(); + if (both_empty) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; + } + + for (const auto &var : incoming_vars) { + existing->addVar(var); } - break; - } else if (((((OpenACCClause *)(current_clause))->getExpressions()) - ->size() == 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() == 0)) { current_clauses->pop_back(); directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } @@ -797,7 +1562,7 @@ OpenACCClause *OpenACCVectorClause::addClause(OpenACCDirective *directive) { OpenACCClause *new_clause = NULL; if (directive->getKind() == ACCD_routine) { if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(ACCC_vector); + new_clause = new OpenACCVectorClause(); current_clauses = new std::vector(); current_clauses->push_back(new_clause); (*all_clauses)[ACCC_vector] = current_clauses; @@ -828,29 +1593,55 @@ void OpenACCVectorClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_vector); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((OpenACCVectorClause *)(*it))->getModifier() == - ((OpenACCVectorClause *)current_clause)->getModifier()) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } - break; - } + if (current_clauses->size() < 2) { + return; + } + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getModifier() == incoming->getModifier() && + existing->getLengthExpr() == incoming->getLengthExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; } } }; +OpenACCClause *OpenACCVectorLengthClause::addClause( + OpenACCDirective *directive) { + + std::map *> *all_clauses = + directive->getAllClauses(); + std::vector *current_clauses = + directive->getClauses(ACCC_vector_length); + OpenACCClause *new_clause = NULL; + if (directive->getKind() == ACCD_routine) { + if (current_clauses->size() == 0) { + new_clause = new OpenACCVectorLengthClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_vector_length] = current_clauses; + } else { + new_clause = current_clauses->at(0); + } + } else { + if (current_clauses->size() == 0) { + new_clause = new OpenACCVectorLengthClause(); + current_clauses = new std::vector(); + current_clauses->push_back(new_clause); + (*all_clauses)[ACCC_vector_length] = current_clauses; + } else { + new_clause = new OpenACCVectorLengthClause(); + current_clauses->push_back(new_clause); + }; + } + + return new_clause; +} + void OpenACCVectorLengthClause::mergeClause(OpenACCDirective *directive, OpenACCClause *current_clause) { // Respect the global clause merging flag @@ -861,21 +1652,18 @@ void OpenACCVectorLengthClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_vector_length); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getLengthExpr() == incoming->getLengthExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; break; } } @@ -909,47 +1697,44 @@ void OpenACCWaitClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_wait); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0) && - ((OpenACCWaitClause *)current_clause)->getDevnum() == - ((OpenACCWaitClause *)(*it))->getDevnum() && - ((OpenACCWaitClause *)current_clause)->getQueues() == - ((OpenACCWaitClause *)(*it))->getQueues()) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - for (std::vector::iterator it_expr_current = - expressions_current_clause->begin(); - it_expr_current != expressions_current_clause->end(); - it_expr_current++) { - bool para_merge = true; - for (std::vector::iterator it_expr_previous = - expressions_previous_clause->begin(); - it_expr_previous != expressions_previous_clause->end(); - it_expr_previous++) { - if (*it_expr_current == *it_expr_previous) { - para_merge = false; + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (incoming->getDevnum() == existing->getDevnum() && + incoming->getQueues() == existing->getQueues()) { + if (incoming->getAsyncIds().empty() && existing->getAsyncIds().empty()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; + } + if (!incoming->getAsyncIds().empty() && + !existing->getAsyncIds().empty()) { + for (const auto &id : incoming->getAsyncIds()) { + bool found = false; + for (const auto &prev : existing->getAsyncIds()) { + if (prev == id) { + found = true; + break; + } + } + if (!found) { + existing->addAsyncId(id); + } } + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; } - if (para_merge == true) - expressions_previous_clause->push_back(*it_expr_current); } - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - break; - } else if (((((OpenACCClause *)(current_clause))->getExpressions()) - ->size() == 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() == 0)) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - break; } - } -}; + }; OpenACCClause *OpenACCWorkerClause::addClause(OpenACCDirective *directive) { @@ -960,7 +1745,7 @@ OpenACCClause *OpenACCWorkerClause::addClause(OpenACCDirective *directive) { OpenACCClause *new_clause = NULL; if (directive->getKind() == ACCD_routine) { if (current_clauses->size() == 0) { - new_clause = new OpenACCClause(ACCC_worker); + new_clause = new OpenACCWorkerClause(); current_clauses = new std::vector(); current_clauses->push_back(new_clause); (*all_clauses)[ACCC_worker] = current_clauses; @@ -991,25 +1776,110 @@ void OpenACCWorkerClause::mergeClause(OpenACCDirective *directive, std::vector *current_clauses = directive->getClauses(ACCC_worker); - for (std::vector::iterator it = current_clauses->begin(); - it != current_clauses->end() - 1; it++) { - if (((OpenACCWorkerClause *)(*it))->getModifier() == - ((OpenACCWorkerClause *)current_clause)->getModifier()) { - if (((((OpenACCClause *)(current_clause))->getExpressions())->size() != - 0) && - ((((OpenACCClause *)(*it))->getExpressions())->size() != 0)) { - std::vector *expressions_previous_clause = - ((OpenACCClause *)(*it))->getExpressions(); - std::vector *expressions_current_clause = - ((OpenACCClause *)(current_clause))->getExpressions(); - std::string new_expression = expressions_current_clause->at(0); - std::string old_expression = expressions_previous_clause->at(0); - if (new_expression == old_expression) { - current_clauses->pop_back(); - directive->getClausesInOriginalOrder()->pop_back(); - } - break; - } + if (current_clauses->size() < 2) { + return; + } + + auto *incoming = static_cast(current_clause); + for (auto it = current_clauses->begin(); it != current_clauses->end() - 1; + ++it) { + auto *existing = static_cast(*it); + if (existing->getModifier() == incoming->getModifier() && + existing->getNumExpr() == incoming->getNumExpr()) { + current_clauses->pop_back(); + directive->getClausesInOriginalOrder()->pop_back(); + delete incoming; + break; } } }; + +static OpenACCDeviceTypeKind parseDeviceTypeKind(const std::string &value) { + std::string lower = value; + std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + if (lower == "host") { + return ACCC_DEVICE_TYPE_host; + } + if (lower == "any") { + return ACCC_DEVICE_TYPE_any; + } + if (lower == "multicore") { + return ACCC_DEVICE_TYPE_multicore; + } + if (lower == "default") { + return ACCC_DEVICE_TYPE_default; + } + return ACCC_DEVICE_TYPE_unknown; +} + +void OpenACCDeviceTypeClause::addDeviceType(OpenACCDeviceTypeKind kind) { + if (kind == ACCC_DEVICE_TYPE_unknown) { + return; + } + for (const auto &existing : device_types) { + if (existing == kind) { + return; + } + } + device_types.push_back(kind); +} + +void OpenACCDeviceTypeClause::addDeviceTypeString(const std::string &value) { + OpenACCDeviceTypeKind kind = parseDeviceTypeKind(value); + if (kind == ACCC_DEVICE_TYPE_unknown) { + if (std::find(unknown_types.begin(), unknown_types.end(), value) == + unknown_types.end()) { + unknown_types.push_back(value); + } + } else { + addDeviceType(kind); + } +} + +OpenACCClause *OpenACCDeviceTypeClause::addClause(OpenACCDirective *directive) { + auto *new_clause = new OpenACCDeviceTypeClause(); + std::vector *current_clauses = + directive->getClauses(ACCC_device_type); + if (current_clauses->size() == 0) { + current_clauses->push_back(new_clause); + } else if (directive->getClauseMerging()) { + delete new_clause; + new_clause = + dynamic_cast(current_clauses->at(0)); + if (new_clause == nullptr) { + new_clause = new OpenACCDeviceTypeClause(); + current_clauses->push_back(new_clause); + } + } else { + current_clauses->push_back(new_clause); + } + return new_clause; +} + +void OpenACCDeviceTypeClause::mergeClause(OpenACCDirective *directive, + OpenACCClause *merge_clause) { + if (!directive->getClauseMerging()) { + return; + } + auto *other = dynamic_cast(merge_clause); + if (other == nullptr) { + return; + } + for (auto k : other->device_types) { + addDeviceType(k); + } + for (const auto &raw : other->unknown_types) { + if (std::find(unknown_types.begin(), unknown_types.end(), raw) == + unknown_types.end()) { + unknown_types.push_back(raw); + } + } + auto *current_clauses = directive->getClauses(ACCC_device_type); + if (!current_clauses->empty() && + current_clauses->back() == merge_clause && + current_clauses->size() > 1) { + current_clauses->pop_back(); + } + directive->getClausesInOriginalOrder()->pop_back(); + delete merge_clause; +} diff --git a/src/OpenACCIR.h b/src/OpenACCIR.h index b24b769..0bcd125 100644 --- a/src/OpenACCIR.h +++ b/src/OpenACCIR.h @@ -1,3 +1,4 @@ +#include #include #include #include @@ -83,6 +84,35 @@ class OpenACCClause : public ACC_SourceLocation { std::string expressionToString(); }; +// Common base for clauses that carry a variable list. +class OpenACCVarListClause : public OpenACCClause { +protected: + std::vector vars; + +public: + OpenACCVarListClause(OpenACCClauseKind k, int _line = 0, int _col = 0) + : OpenACCClause(k, _line, _col) {} + + void addVar(const std::string &expr) { + if (std::find(vars.begin(), vars.end(), expr) == vars.end()) { + vars.push_back(expr); + } + } + + const std::vector &getVars() const { return vars; } + + std::string varsToString() const { + std::string out; + for (auto it = vars.begin(); it != vars.end(); ++it) { + out += *it; + if (it + 1 != vars.end()) { + out += ", "; + } + } + return out; + } +}; + /** * The class for all the OpenACC directives */ @@ -207,7 +237,7 @@ class OpenACCDirective : public ACC_SourceLocation { class OpenACCCacheDirective : public OpenACCDirective { protected: OpenACCCacheDirectiveModifier modifier = ACCC_CACHE_unspecified; - std::vector expressions; + std::vector vars; public: OpenACCCacheDirective() : OpenACCDirective(ACCD_cache){}; @@ -215,10 +245,23 @@ class OpenACCCacheDirective : public OpenACCDirective { void setModifier(OpenACCCacheDirectiveModifier _modifier) { modifier = _modifier; }; - std::vector *getExpressions() { return &expressions; }; - void addVar(std::string _string) { expressions.push_back(_string); }; + const std::vector &getVars() const { return vars; } + void addVar(const std::string &_string) { + if (std::find(vars.begin(), vars.end(), _string) == vars.end()) { + vars.push_back(_string); + } + }; + std::string varsToString() const { + std::string out; + for (auto it = vars.begin(); it != vars.end(); ++it) { + out += *it; + if (it + 1 != vars.end()) { + out += ", "; + } + } + return out; + } std::string toString(); - std::string expressionToString(); }; // End directive @@ -238,17 +281,22 @@ class OpenACCEndDirective : public OpenACCDirective { class OpenACCRoutineDirective : public OpenACCDirective { protected: std::string name = ""; + bool name_is_string_literal = false; public: OpenACCRoutineDirective() : OpenACCDirective(ACCD_routine){}; - void setName(std::string _name) { name = _name; }; + void setName(std::string _name, bool is_string_literal = false) { + name = _name; + name_is_string_literal = is_string_literal; + }; std::string getName() { return name; }; + bool isNameStringLiteral() const { return name_is_string_literal; }; }; // Wait directive class OpenACCWaitDirective : public OpenACCDirective { protected: - std::vector expressions; + std::vector async_ids; std::string devnum = ""; bool queues = false; @@ -258,8 +306,8 @@ class OpenACCWaitDirective : public OpenACCDirective { std::string getDevnum() { return devnum; }; void setQueues(bool _queues) { queues = _queues; }; bool getQueues() { return queues; }; - std::vector *getExpressions() { return &expressions; }; - void addVar(std::string _string) { expressions.push_back(_string); }; + std::vector *getAsyncIds() { return &async_ids; }; + void addAsyncId(std::string _string) { async_ids.push_back(_string); }; std::string toString(); std::string expressionToString(); }; @@ -267,9 +315,18 @@ class OpenACCWaitDirective : public OpenACCDirective { // Async Clause class OpenACCAsyncClause : public OpenACCClause { +protected: + OpenACCAsyncModifier modifier = ACCC_ASYNC_unspecified; + std::string async_expr; + public: OpenACCAsyncClause() : OpenACCClause(ACCC_async){}; + void setModifier(OpenACCAsyncModifier m) { modifier = m; } + OpenACCAsyncModifier getModifier() const { return modifier; } + void setAsyncExpr(const std::string &expr) { async_expr = expr; } + const std::string &getAsyncExpr() const { return async_expr; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); @@ -278,9 +335,20 @@ class OpenACCAsyncClause : public OpenACCClause { // Bind Clause class OpenACCBindClause : public OpenACCClause { +protected: + std::string binding = ""; + bool is_string_literal = false; + public: OpenACCBindClause() : OpenACCClause(ACCC_bind){}; + void setBinding(const std::string &_binding, bool _is_string_literal) { + binding = _binding; + is_string_literal = _is_string_literal; + } + const std::string &getBinding() const { return binding; } + bool isStringLiteral() const { return is_string_literal; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); @@ -289,22 +357,39 @@ class OpenACCBindClause : public OpenACCClause { // Collapse Clause class OpenACCCollapseClause : public OpenACCClause { +protected: + std::vector counts; + public: OpenACCCollapseClause() : OpenACCClause(ACCC_collapse){}; + void addCountExpr(const std::string &expr) { counts.push_back(expr); } + const std::vector &getCounts() const { return counts; } + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Copy Clause +class OpenACCCopyClause : public OpenACCVarListClause { + +public: + OpenACCCopyClause() : OpenACCVarListClause(ACCC_copy) {} + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); }; // Copyin Clause -class OpenACCCopyinClause : public OpenACCClause { +class OpenACCCopyinClause : public OpenACCVarListClause { protected: OpenACCCopyinClauseModifier modifier = ACCC_COPYIN_unspecified; public: - OpenACCCopyinClause() : OpenACCClause(ACCC_copyin){}; + OpenACCCopyinClause() : OpenACCVarListClause(ACCC_copyin){}; OpenACCCopyinClauseModifier getModifier() { return modifier; }; @@ -318,13 +403,13 @@ class OpenACCCopyinClause : public OpenACCClause { }; // Copyout Clause -class OpenACCCopyoutClause : public OpenACCClause { +class OpenACCCopyoutClause : public OpenACCVarListClause { protected: OpenACCCopyoutClauseModifier modifier = ACCC_COPYOUT_unspecified; public: - OpenACCCopyoutClause() : OpenACCClause(ACCC_copyout){}; + OpenACCCopyoutClause() : OpenACCVarListClause(ACCC_copyout){}; OpenACCCopyoutClauseModifier getModifier() { return modifier; }; @@ -338,13 +423,13 @@ class OpenACCCopyoutClause : public OpenACCClause { }; // Create Clause -class OpenACCCreateClause : public OpenACCClause { +class OpenACCCreateClause : public OpenACCVarListClause { protected: OpenACCCreateClauseModifier modifier = ACCC_CREATE_unspecified; public: - OpenACCCreateClause() : OpenACCClause(ACCC_create){}; + OpenACCCreateClause() : OpenACCVarListClause(ACCC_create){}; OpenACCCreateClauseModifier getModifier() { return modifier; }; @@ -357,6 +442,67 @@ class OpenACCCreateClause : public OpenACCClause { void mergeClause(OpenACCDirective *, OpenACCClause *); }; +// No_create Clause +class OpenACCNoCreateClause : public OpenACCVarListClause { +public: + OpenACCNoCreateClause() : OpenACCVarListClause(ACCC_no_create) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Present Clause +class OpenACCPresentClause : public OpenACCVarListClause { +public: + OpenACCPresentClause() : OpenACCVarListClause(ACCC_present) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Link Clause +class OpenACCLinkClause : public OpenACCVarListClause { +public: + OpenACCLinkClause() : OpenACCVarListClause(ACCC_link) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Device Resident Clause +class OpenACCDeviceResidentClause : public OpenACCVarListClause { +public: + OpenACCDeviceResidentClause() + : OpenACCVarListClause(ACCC_device_resident) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Deviceptr Clause +class OpenACCDeviceptrClause : public OpenACCVarListClause { +public: + OpenACCDeviceptrClause() : OpenACCVarListClause(ACCC_deviceptr) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Use_device Clause +class OpenACCUseDeviceClause : public OpenACCVarListClause { +public: + OpenACCUseDeviceClause() : OpenACCVarListClause(ACCC_use_device) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + // Default Clause class OpenACCDefaultClause : public OpenACCClause { @@ -379,19 +525,112 @@ class OpenACCDefaultClause : public OpenACCClause { // Default_async Clause class OpenACCDefaultAsyncClause : public OpenACCClause { +protected: + std::string async_expr; + public: OpenACCDefaultAsyncClause() : OpenACCClause(ACCC_default_async){}; + void setAsyncExpr(const std::string &expr) { async_expr = expr; } + const std::string &getAsyncExpr() const { return async_expr; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); }; -// Device_num Clause -class OpenACCDeviceNumClause : public OpenACCClause { +// Attach Clause +class OpenACCAttachClause : public OpenACCVarListClause { +public: + OpenACCAttachClause() : OpenACCVarListClause(ACCC_attach) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Delete Clause +class OpenACCDeleteClause : public OpenACCVarListClause { +public: + OpenACCDeleteClause() : OpenACCVarListClause(ACCC_delete) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Detach Clause +class OpenACCDetachClause : public OpenACCVarListClause { +public: + OpenACCDetachClause() : OpenACCVarListClause(ACCC_detach) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Device Clause +class OpenACCDeviceClause : public OpenACCClause { + +protected: + std::vector devices; public: - OpenACCDeviceNumClause() : OpenACCClause(ACCC_device_num){}; + OpenACCDeviceClause() : OpenACCClause(ACCC_device) {} + + void addDevice(const std::string &expr) { + if (std::find(devices.begin(), devices.end(), expr) == devices.end()) { + devices.push_back(expr); + } + } + const std::vector &getDevices() const { return devices; } + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Firstprivate Clause +class OpenACCFirstprivateClause : public OpenACCVarListClause { +public: + OpenACCFirstprivateClause() : OpenACCVarListClause(ACCC_firstprivate) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Host Clause +class OpenACCHostClause : public OpenACCVarListClause { +public: + OpenACCHostClause() : OpenACCVarListClause(ACCC_host) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Private Clause +class OpenACCPrivateClause : public OpenACCVarListClause { +public: + OpenACCPrivateClause() : OpenACCVarListClause(ACCC_private) {} + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// If Clause +class OpenACCIfClause : public OpenACCClause { + +protected: + std::string condition; + +public: + OpenACCIfClause() : OpenACCClause(ACCC_if) {} + + void setCondition(const std::string &expr) { condition = expr; } + const std::string &getCondition() const { return condition; } static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); @@ -401,9 +640,23 @@ class OpenACCDeviceNumClause : public OpenACCClause { // Gang Clause class OpenACCGangClause : public OpenACCClause { +public: + struct GangArg { + OpenACCGangArgKind kind; + std::string value; + }; + +protected: + std::vector args; + public: OpenACCGangClause() : OpenACCClause(ACCC_gang){}; + void addArg(OpenACCGangArgKind kind, const std::string &value) { + args.push_back({kind, value}); + } + const std::vector &getArgs() const { return args; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); @@ -411,10 +664,15 @@ class OpenACCGangClause : public OpenACCClause { // Num_gangs Clause class OpenACCNumGangsClause : public OpenACCClause { +protected: + std::vector nums; public: OpenACCNumGangsClause() : OpenACCClause(ACCC_num_gangs){}; + void addNum(const std::string &expr) { nums.push_back(expr); } + const std::vector &getNums() const { return nums; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); @@ -423,23 +681,61 @@ class OpenACCNumGangsClause : public OpenACCClause { // Num_workers Clause class OpenACCNumWorkersClause : public OpenACCClause { +protected: + std::string num_expr; + public: OpenACCNumWorkersClause() : OpenACCClause(ACCC_num_workers){}; + void setNumExpr(const std::string &expr) { num_expr = expr; } + const std::string &getNumExpr() const { return num_expr; } + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Device_num Clause +class OpenACCDeviceNumClause : public OpenACCClause { +protected: + std::string device_expr; + +public: + OpenACCDeviceNumClause() : OpenACCClause(ACCC_device_num) {} + + void setDeviceExpr(const std::string &expr) { device_expr = expr; } + const std::string &getDeviceExpr() const { return device_expr; } + + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + +// Tile Clause +class OpenACCTileClause : public OpenACCClause { +protected: + std::vector tile_sizes; + +public: + OpenACCTileClause() : OpenACCClause(ACCC_tile) {} + + void addTileSize(const std::string &expr) { tile_sizes.push_back(expr); } + const std::vector &getTileSizes() const { return tile_sizes; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); }; // Reduction Clause -class OpenACCReductionClause : public OpenACCClause { +class OpenACCReductionClause : public OpenACCVarListClause { protected: OpenACCReductionClauseOperator reduction_operator = ACCC_REDUCTION_unspecified; public: - OpenACCReductionClause() : OpenACCClause(ACCC_reduction){}; + OpenACCReductionClause() : OpenACCVarListClause(ACCC_reduction){}; OpenACCReductionClauseOperator getOperator() { return reduction_operator; }; @@ -452,11 +748,17 @@ class OpenACCReductionClause : public OpenACCClause { void mergeClause(OpenACCDirective *, OpenACCClause *); }; -// Self Clause -class OpenACCSelfClause : public OpenACCClause { +// Self Clause: supports either a condition or a variable list. +class OpenACCSelfClause : public OpenACCVarListClause { + +protected: + std::string condition; public: - OpenACCSelfClause() : OpenACCClause(ACCC_self){}; + OpenACCSelfClause() : OpenACCVarListClause(ACCC_self) {} + + void setCondition(const std::string &expr) { condition = expr; } + const std::string &getCondition() const { return condition; } static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); @@ -465,18 +767,20 @@ class OpenACCSelfClause : public OpenACCClause { // Vector Clause class OpenACCVectorClause : public OpenACCClause { - protected: OpenACCVectorClauseModifier modifier = ACCC_VECTOR_unspecified; + std::string length_expr; public: OpenACCVectorClause() : OpenACCClause(ACCC_vector){}; - OpenACCVectorClauseModifier getModifier() { return modifier; }; + OpenACCVectorClauseModifier getModifier() const { return modifier; }; void setModifier(OpenACCVectorClauseModifier _modifier) { modifier = _modifier; }; + void setLengthExpr(const std::string &expr) { length_expr = expr; } + const std::string &getLengthExpr() const { return length_expr; } static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); @@ -486,9 +790,15 @@ class OpenACCVectorClause : public OpenACCClause { // Vector_length Clause class OpenACCVectorLengthClause : public OpenACCClause { +protected: + std::string length_expr; + public: OpenACCVectorLengthClause() : OpenACCClause(ACCC_vector_length){}; + void setLengthExpr(const std::string &expr) { length_expr = expr; } + const std::string &getLengthExpr() const { return length_expr; } + static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); @@ -500,6 +810,7 @@ class OpenACCWaitClause : public OpenACCClause { protected: std::string devnum = ""; bool queues = false; + std::vector async_ids; public: OpenACCWaitClause() : OpenACCClause(ACCC_wait){}; @@ -508,26 +819,53 @@ class OpenACCWaitClause : public OpenACCClause { std::string getDevnum() { return devnum; }; void setQueues(bool _queues) { queues = _queues; }; bool getQueues() { return queues; }; + void addAsyncId(const std::string &expr) { async_ids.push_back(expr); } + const std::vector &getAsyncIds() const { return async_ids; } static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); void mergeClause(OpenACCDirective *, OpenACCClause *); }; +// Device_type Clause +class OpenACCDeviceTypeClause : public OpenACCClause { +protected: + std::vector device_types; + std::vector unknown_types; + +public: + OpenACCDeviceTypeClause() : OpenACCClause(ACCC_device_type) {} + + void addDeviceType(OpenACCDeviceTypeKind kind); + void addDeviceTypeString(const std::string &value); + const std::vector &getDeviceTypes() const { + return device_types; + } + const std::vector &getUnknownDeviceTypes() const { + return unknown_types; + } + static OpenACCClause *addClause(OpenACCDirective *); + std::string toString(); + void mergeClause(OpenACCDirective *, OpenACCClause *); +}; + // Worker Clause class OpenACCWorkerClause : public OpenACCClause { protected: OpenACCWorkerClauseModifier modifier = ACCC_WORKER_unspecified; + std::string num_expr; public: OpenACCWorkerClause() : OpenACCClause(ACCC_worker){}; - OpenACCWorkerClauseModifier getModifier() { return modifier; }; + OpenACCWorkerClauseModifier getModifier() const { return modifier; }; void setModifier(OpenACCWorkerClauseModifier _modifier) { modifier = _modifier; }; + void setNumExpr(const std::string &expr) { num_expr = expr; } + const std::string &getNumExpr() const { return num_expr; } static OpenACCClause *addClause(OpenACCDirective *); std::string toString(); diff --git a/src/OpenACCIRToString.cpp b/src/OpenACCIRToString.cpp index 7caea43..9a42fb1 100644 --- a/src/OpenACCIRToString.cpp +++ b/src/OpenACCIRToString.cpp @@ -28,11 +28,22 @@ std::string OpenACCDirective::generatePragmaString(std::string prefix, std::vector *clauses = this->getClausesInOriginalOrder(); if (clauses->size() != 0) { - std::vector::iterator iter; - for (iter = clauses->begin(); iter != clauses->end(); iter++) { - result += (*iter)->toString(); + bool first = true; + for (auto *clause : *clauses) { + std::string clause_str = clause->toString(); + // Strip trailing whitespace the clause printer may have added. + while (!clause_str.empty() && std::isspace(clause_str.back())) { + clause_str.pop_back(); + } + if (clause_str.empty()) { + continue; + } + if (!first) { + result += " "; + } + result += clause_str; + first = false; } - result = result.substr(0, result.size() - 1); } result += ending_symbol; @@ -85,9 +96,9 @@ std::string OpenACCDirective::toString() { break; case ACCD_routine: result += "routine "; - if (((OpenACCRoutineDirective *)this)->getName() != "") + if (!((OpenACCRoutineDirective *)this)->getName().empty()) { result += "(" + ((OpenACCRoutineDirective *)this)->getName() + ") "; - ; + } break; case ACCD_serial: result += "serial "; @@ -116,12 +127,13 @@ std::string OpenACCClause::expressionToString() { std::string result; std::vector *expr = this->getExpressions(); - if (expr != NULL) { - std::vector::iterator it; - for (it = expr->begin(); it != expr->end(); it++) { - result += *it + ", "; - }; - result = result.substr(0, result.size() - 2); + if (expr != NULL && !expr->empty()) { + for (auto it = expr->begin(); it != expr->end(); ++it) { + if (it != expr->begin()) { + result += ", "; + } + result += *it; + } } return result; @@ -133,69 +145,67 @@ std::string OpenACCClause::toString() { switch (this->getKind()) { case ACCC_async: - result += "async "; - break; + return static_cast(this)->OpenACCAsyncClause::toString(); case ACCC_attach: - result += "attach "; - break; + return static_cast(this) + ->OpenACCAttachClause::toString(); case ACCC_auto: result += "auto "; break; case ACCC_bind: - result += "bind "; - break; + return static_cast(this)->OpenACCBindClause::toString(); case ACCC_capture: result += "capture "; break; + case ACCC_delete: + return static_cast(this) + ->OpenACCDeleteClause::toString(); + case ACCC_detach: + return static_cast(this) + ->OpenACCDetachClause::toString(); case ACCC_collapse: - result += "collapse "; - break; + return static_cast(this) + ->OpenACCCollapseClause::toString(); case ACCC_copy: - if (!original_keyword.empty()) { - result += original_keyword + " "; - } else { - result += "copy "; - } - break; + return static_cast(this)->OpenACCCopyClause::toString(); + case ACCC_copyin: + return static_cast(this) + ->OpenACCCopyinClause::toString(); + case ACCC_copyout: + return static_cast(this) + ->OpenACCCopyoutClause::toString(); + case ACCC_create: + return static_cast(this) + ->OpenACCCreateClause::toString(); case ACCC_default_async: - result += "default_async "; - break; - case ACCC_delete: - result += "delete "; - break; - case ACCC_detach: - result += "detach "; - break; + return static_cast(this) + ->OpenACCDefaultAsyncClause::toString(); case ACCC_device: - result += "device "; - break; + return static_cast(this)->OpenACCDeviceClause::toString(); case ACCC_device_num: - result += "device_num "; - break; + return static_cast(this) + ->OpenACCDeviceNumClause::toString(); case ACCC_device_resident: - result += "device_resident "; - break; + return static_cast(this) + ->OpenACCDeviceResidentClause::toString(); case ACCC_device_type: - result += "device_type "; - break; + return static_cast(this) + ->OpenACCDeviceTypeClause::toString(); case ACCC_deviceptr: - result += "deviceptr "; - break; + return static_cast(this) + ->OpenACCDeviceptrClause::toString(); case ACCC_finalize: result += "finalize "; break; case ACCC_firstprivate: - result += "firstprivate "; - break; + return static_cast(this) + ->OpenACCFirstprivateClause::toString(); case ACCC_gang: - result += "gang "; - break; + return static_cast(this)->OpenACCGangClause::toString(); case ACCC_host: - result += "host "; - break; + return static_cast(this)->OpenACCHostClause::toString(); case ACCC_if: - result += "if "; - break; + return static_cast(this)->OpenACCIfClause::toString(); case ACCC_if_present: result += "if_present "; break; @@ -203,47 +213,44 @@ std::string OpenACCClause::toString() { result += "independent "; break; case ACCC_link: - result += "link "; - break; + return static_cast(this)->OpenACCLinkClause::toString(); case ACCC_nohost: result += "nohost "; break; case ACCC_no_create: - result += "no_create "; - break; + return static_cast(this) + ->OpenACCNoCreateClause::toString(); case ACCC_num_gangs: - result += "num_gangs "; - break; + return static_cast(this) + ->OpenACCNumGangsClause::toString(); case ACCC_num_workers: - result += "num_workers "; - break; + return static_cast(this) + ->OpenACCNumWorkersClause::toString(); case ACCC_present: - result += "present "; - break; + return static_cast(this) + ->OpenACCPresentClause::toString(); case ACCC_private: - result += "private "; - break; + return static_cast(this) + ->OpenACCPrivateClause::toString(); case ACCC_read: result += "read "; break; case ACCC_self: - result += "self "; - break; + return static_cast(this)->OpenACCSelfClause::toString(); case ACCC_seq: result += "seq "; break; case ACCC_tile: - result += "tile "; - break; + return static_cast(this)->OpenACCTileClause::toString(); case ACCC_update: result += "update "; break; case ACCC_use_device: - result += "use_device "; - break; + return static_cast(this) + ->OpenACCUseDeviceClause::toString(); case ACCC_vector_length: - result += "vector_length "; - break; + return static_cast(this) + ->OpenACCVectorLengthClause::toString(); case ACCC_wait: result += "wait "; break; @@ -251,13 +258,13 @@ std::string OpenACCClause::toString() { result += "write "; break; case ACCC_vector: - result += "vector "; - break; + return static_cast(this)->OpenACCVectorClause::toString(); case ACCC_worker: - result += "worker "; - break; + return static_cast(this)->OpenACCWorkerClause::toString(); default: - printf("The clause enum is not supported yet.\n"); + std::cerr << "Unsupported OpenACC clause kind in toString(): " + << this->getKind() << std::endl; + assert(false && "Unsupported OpenACC clause kind in toString()"); } std::string clause_string = "("; @@ -274,20 +281,52 @@ std::string OpenACCClause::toString() { return result; }; -std::string OpenACCCacheDirective::expressionToString() { +static std::string deviceTypeToString(OpenACCDeviceTypeKind kind) { + switch (kind) { + case ACCC_DEVICE_TYPE_host: + return "host"; + case ACCC_DEVICE_TYPE_any: + return "any"; + case ACCC_DEVICE_TYPE_multicore: + return "multicore"; + case ACCC_DEVICE_TYPE_default: + return "default"; + default: + return ""; + } +} - std::string result; - std::vector *expr = this->getExpressions(); - if (expr != NULL) { - std::vector::iterator it; - for (it = expr->begin(); it != expr->end(); it++) { - result += *it + ", "; - }; - result = result.substr(0, result.size() - 2); +std::string OpenACCDeviceTypeClause::toString() { + std::string result = "device_type"; + std::string clause_string = ""; + + bool first = true; + for (auto kind : device_types) { + std::string name = deviceTypeToString(kind); + if (name.empty()) { + continue; + } + if (!first) { + clause_string += ", "; + } + clause_string += name; + first = false; + } + for (const auto &raw : unknown_types) { + if (!first) { + clause_string += ", "; + } + clause_string += raw; + first = false; } + if (!clause_string.empty()) { + result += "(" + clause_string + ") "; + } else { + result += " "; + } return result; -}; +} std::string OpenACCCacheDirective::toString() { @@ -300,26 +339,56 @@ std::string OpenACCCacheDirective::toString() { break; default:; }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { + parameter_string += this->varsToString(); + if (!parameter_string.empty()) { result += "(" + parameter_string + ") "; - } else { - result += " "; + return result; } + // No vars: drop the trailing space when no modifier or vars were present. + result += " "; + return result; }; +std::string OpenACCCollapseClause::toString() { + + std::string result; + const auto &vals = getCounts(); + if (vals.empty()) { + result = "collapse "; + return result; + } + + for (const auto &val : vals) { + result += "collapse"; + result += "(" + val + ") "; + } + return result; +} + +std::string OpenACCAsyncClause::toString() { + + std::string result = "async"; + if (getModifier() == ACCC_ASYNC_expr && !getAsyncExpr().empty()) { + result += "(" + getAsyncExpr() + ") "; + } else { + result += " "; + } + return result; +} + std::string OpenACCWaitDirective::expressionToString() { std::string result; - std::vector *expr = this->getExpressions(); - if (expr != NULL) { - std::vector::iterator it; - for (it = expr->begin(); it != expr->end(); it++) { - result += *it + ", "; - }; - result = result.substr(0, result.size() - 2); + std::vector *expr = this->getAsyncIds(); + if (expr != NULL && !expr->empty()) { + for (auto it = expr->begin(); it != expr->end(); ++it) { + if (it != expr->begin()) { + result += ", "; + } + result += *it; + } } return result; @@ -329,7 +398,7 @@ std::string OpenACCWaitDirective::toString() { std::string result = "wait"; std::string parameter_string = ""; - if (this->getExpressions()->size() != 0) { + if (!this->getAsyncIds()->empty()) { result += "("; std::string devnum = this->getDevnum(); if (devnum != "") { @@ -347,10 +416,154 @@ std::string OpenACCWaitDirective::toString() { return result; }; +std::string OpenACCDefaultAsyncClause::toString() { + + std::string result = "default_async"; + if (!getAsyncExpr().empty()) { + result += "(" + getAsyncExpr() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCDeviceClause::toString() { + + std::string result = "device"; + const auto &devs = getDevices(); + if (!devs.empty()) { + result += "("; + for (auto it = devs.begin(); it != devs.end(); ++it) { + result += *it; + if (it + 1 != devs.end()) { + result += ", "; + } + } + result += ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCIfClause::toString() { + + std::string result = "if"; + if (!getCondition().empty()) { + result += "(" + getCondition() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCDeviceNumClause::toString() { + + std::string result = "device_num"; + if (!getDeviceExpr().empty()) { + result += "(" + getDeviceExpr() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCGangClause::toString() { + + std::string result = "gang"; + if (getArgs().empty()) { + result += " "; + return result; + } + + std::string parameter_string; + const auto &arg_list = getArgs(); + for (auto it = arg_list.begin(); it != arg_list.end(); ++it) { + if (it != arg_list.begin()) { + parameter_string += ", "; + } + switch (it->kind) { + case ACCC_GANG_ARG_num: + parameter_string += "num:" + it->value; + break; + case ACCC_GANG_ARG_dim: + parameter_string += "dim:" + it->value; + break; + case ACCC_GANG_ARG_static: + parameter_string += "static:" + it->value; + break; + default: + parameter_string += it->value; + break; + } + } + result += "(" + parameter_string + ") "; + return result; +} + +std::string OpenACCNumGangsClause::toString() { + + std::string result = "num_gangs"; + const auto &vals = getNums(); + if (!vals.empty()) { + result += "("; + for (auto it = vals.begin(); it != vals.end(); ++it) { + result += *it; + if (it + 1 != vals.end()) + result += ", "; + } + result += ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCNumWorkersClause::toString() { + + std::string result = "num_workers"; + if (!getNumExpr().empty()) { + result += "(" + getNumExpr() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCTileClause::toString() { + + std::string result = "tile"; + const auto &sizes = getTileSizes(); + if (!sizes.empty()) { + result += "("; + for (auto it = sizes.begin(); it != sizes.end(); ++it) { + result += *it; + if (it + 1 != sizes.end()) { + result += ", "; + } + } + result += ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCBindClause::toString() { + + std::string result = "bind"; + if (!getBinding().empty()) { + result += "(" + getBinding() + ") "; + } else { + result += " "; + } + return result; +} + std::string OpenACCCopyinClause::toString() { std::string keyword = original_keyword.empty() ? "copyin" : original_keyword; - std::string result = keyword + "("; + std::string result = keyword; std::string parameter_string = ""; OpenACCCopyinClauseModifier modifier = this->getModifier(); switch (modifier) { @@ -359,11 +572,11 @@ std::string OpenACCCopyinClause::toString() { break; default:; }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += parameter_string + ") "; + parameter_string += this->varsToString(); + if (!parameter_string.empty()) { + result += "(" + parameter_string + ") "; } else { - result = result.substr(0, result.size() - 1); + result += " "; } return result; @@ -372,7 +585,7 @@ std::string OpenACCCopyinClause::toString() { std::string OpenACCCopyoutClause::toString() { std::string keyword = original_keyword.empty() ? "copyout" : original_keyword; - std::string result = keyword + "("; + std::string result = keyword; std::string parameter_string = ""; OpenACCCopyoutClauseModifier modifier = this->getModifier(); switch (modifier) { @@ -381,11 +594,11 @@ std::string OpenACCCopyoutClause::toString() { break; default:; }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += parameter_string + ") "; + parameter_string += this->varsToString(); + if (!parameter_string.empty()) { + result += "(" + parameter_string + ") "; } else { - result = result.substr(0, result.size() - 1); + result += " "; } return result; @@ -394,7 +607,7 @@ std::string OpenACCCopyoutClause::toString() { std::string OpenACCCreateClause::toString() { std::string keyword = original_keyword.empty() ? "create" : original_keyword; - std::string result = keyword + "("; + std::string result = keyword; std::string parameter_string = ""; OpenACCCreateClauseModifier modifier = this->getModifier(); switch (modifier) { @@ -403,16 +616,147 @@ std::string OpenACCCreateClause::toString() { break; default:; }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += parameter_string + ") "; + parameter_string += this->varsToString(); + if (!parameter_string.empty()) { + result += "(" + parameter_string + ") "; } else { - result = result.substr(0, result.size() - 1); + result += " "; } return result; }; +static std::string +varClauseToString(const std::string &keyword, + const std::vector &vars) { + std::string result = keyword; + if (!vars.empty()) { + result += "("; + for (auto it = vars.begin(); it != vars.end(); ++it) { + result += *it; + if (it + 1 != vars.end()) { + result += ", "; + } + } + result += ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCCopyClause::toString() { + std::string keyword = original_keyword.empty() ? "copy" : original_keyword; + std::string result = keyword; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCNoCreateClause::toString() { + std::string result = "no_create"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCPresentClause::toString() { + std::string result = "present"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCLinkClause::toString() { + std::string result = "link"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCDeviceResidentClause::toString() { + std::string result = "device_resident"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCDeviceptrClause::toString() { + std::string result = "deviceptr"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCUseDeviceClause::toString() { + std::string result = "use_device"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCAttachClause::toString() { + std::string result = "attach"; + if (!getVars().empty()) { + result += "(" + varsToString() + ") "; + } else { + result += " "; + } + return result; +} + +std::string OpenACCDeleteClause::toString() { + return varClauseToString("delete", getVars()); +} + +std::string OpenACCDetachClause::toString() { + return varClauseToString("detach", getVars()); +} + +std::string OpenACCSelfClause::toString() { + std::string result = "self"; + if (!getCondition().empty()) { + result += "(" + getCondition() + ")"; + } else if (!getVars().empty()) { + result += "(" + varsToString() + ")"; + } + result += " "; + return result; +} + +std::string OpenACCFirstprivateClause::toString() { + return varClauseToString("firstprivate", getVars()); +} + +std::string OpenACCHostClause::toString() { + return varClauseToString("host", getVars()); +} + +std::string OpenACCPrivateClause::toString() { + return varClauseToString("private", getVars()); +} + std::string OpenACCDefaultClause::toString() { std::string result = "default"; @@ -441,68 +785,68 @@ std::string OpenACCDefaultClause::toString() { std::string OpenACCReductionClause::toString() { std::string result = "reduction("; - std::string parameter_string = ""; + std::string op; OpenACCReductionClauseOperator reduction_operator = this->getOperator(); switch (reduction_operator) { case ACCC_REDUCTION_add: - parameter_string = "+ : "; + op = "+"; break; case ACCC_REDUCTION_sub: - parameter_string = "- : "; + op = "-"; break; case ACCC_REDUCTION_mul: - parameter_string = "* : "; + op = "*"; break; case ACCC_REDUCTION_max: - parameter_string = "max : "; + op = "max"; break; case ACCC_REDUCTION_min: - parameter_string = "min : "; + op = "min"; break; case ACCC_REDUCTION_bitand: - parameter_string = "& : "; + op = "&"; break; case ACCC_REDUCTION_bitor: - parameter_string = "| : "; + op = "|"; break; case ACCC_REDUCTION_bitxor: - parameter_string = "^ : "; + op = "^"; break; case ACCC_REDUCTION_logand: - parameter_string = "&& : "; + op = "&&"; break; case ACCC_REDUCTION_logor: - parameter_string = "|| : "; + op = "||"; break; case ACCC_REDUCTION_fort_and: - parameter_string = ".and. : "; + op = ".and."; break; case ACCC_REDUCTION_fort_or: - parameter_string = ".or. : "; + op = ".or."; break; case ACCC_REDUCTION_fort_eqv: - parameter_string = ".eqv. : "; + op = ".eqv."; break; case ACCC_REDUCTION_fort_neqv: - parameter_string = ".neqv. : "; + op = ".neqv."; break; case ACCC_REDUCTION_fort_iand: - parameter_string = "iand : "; + op = ".iand."; break; case ACCC_REDUCTION_fort_ior: - parameter_string = "ior : "; + op = ".ior."; break; case ACCC_REDUCTION_fort_ieor: - parameter_string = "ieor : "; + op = ".ieor."; break; - default:; + default: + op = ""; }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += parameter_string + ") "; - } else { - result = result.substr(0, result.size() - 1); - } + + result += op; + result += " : "; + result += this->varsToString(); + result += ") "; return result; }; @@ -510,29 +854,42 @@ std::string OpenACCReductionClause::toString() { std::string OpenACCVectorClause::toString() { std::string result = "vector"; - std::string parameter_string = ""; - OpenACCVectorClauseModifier modifier = this->getModifier(); - switch (modifier) { + const std::string &length = this->getLengthExpr(); + switch (this->getModifier()) { case ACCC_VECTOR_length: - parameter_string = "length: "; + result += "(length: " + length + ") "; break; - default:; - }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += "(" + parameter_string + ") "; - } else { - result += " "; + case ACCC_VECTOR_expr_only: + result += "(" + length + ") "; + break; + default: + if (!length.empty()) { + result += "(" + length + ") "; + } else { + result += " "; + } } return result; }; +std::string OpenACCVectorLengthClause::toString() { + + std::string result = "vector_length"; + const std::string &length = this->getLengthExpr(); + if (!length.empty()) { + result += "(" + length + ") "; + } else { + result += " "; + } + return result; +} + std::string OpenACCWaitClause::toString() { std::string result = "wait"; std::string parameter_string = ""; - if (this->getExpressions()->size() != 0) { + if (!this->getAsyncIds().empty()) { result += "("; std::string devnum = this->getDevnum(); if (devnum != "") { @@ -542,7 +899,13 @@ std::string OpenACCWaitClause::toString() { parameter_string += "queues: "; }; - parameter_string += this->expressionToString(); + const auto &ids = this->getAsyncIds(); + for (auto it = ids.begin(); it != ids.end(); ++it) { + parameter_string += *it; + if (it + 1 != ids.end()) { + parameter_string += ", "; + } + } result += parameter_string + ") "; } else { result += " "; @@ -554,19 +917,20 @@ std::string OpenACCWaitClause::toString() { std::string OpenACCWorkerClause::toString() { std::string result = "worker"; - std::string parameter_string = ""; - OpenACCWorkerClauseModifier modifier = this->getModifier(); - switch (modifier) { + const std::string &num = this->getNumExpr(); + switch (this->getModifier()) { case ACCC_WORKER_num: - parameter_string = "num: "; + result += "(num: " + num + ") "; break; - default:; - }; - parameter_string += this->expressionToString(); - if (parameter_string.size() > 0) { - result += "(" + parameter_string + ") "; - } else { - result += " "; + case ACCC_WORKER_expr_only: + result += "(" + num + ") "; + break; + default: + if (!num.empty()) { + result += "(" + num + ") "; + } else { + result += " "; + } } return result; diff --git a/src/OpenACCKinds.h b/src/OpenACCKinds.h index 9652066..b781a35 100644 --- a/src/OpenACCKinds.h +++ b/src/OpenACCKinds.h @@ -89,6 +89,14 @@ enum OpenACCClauseKind { #undef OPENACC_CLAUSE }; +enum OpenACCDeviceTypeKind { + ACCC_DEVICE_TYPE_unknown, + ACCC_DEVICE_TYPE_host, + ACCC_DEVICE_TYPE_any, + ACCC_DEVICE_TYPE_multicore, + ACCC_DEVICE_TYPE_default +}; + // OpenACC attributes for 'cache' directive. enum OpenACCCacheDirectiveModifier { #define OPENACC_CACHE_MODIFIER(Name) ACCC_CACHE_##Name, @@ -166,6 +174,7 @@ enum OpenACCVectorClauseModifier { #define OPENACC_VECTOR_MODIFIER(Name) ACCC_VECTOR_##Name, OPENACC_VECTOR_MODIFIER(unspecified) OPENACC_VECTOR_MODIFIER(length) + OPENACC_VECTOR_MODIFIER(expr_only) OPENACC_VECTOR_MODIFIER(unknown) #undef OPENACC_VECTOR_MODIFIER }; @@ -175,9 +184,22 @@ enum OpenACCWorkerClauseModifier { #define OPENACC_WORKER_MODIFIER(Name) ACCC_WORKER_##Name, OPENACC_WORKER_MODIFIER(unspecified) OPENACC_WORKER_MODIFIER(num) + OPENACC_WORKER_MODIFIER(expr_only) OPENACC_WORKER_MODIFIER(unknown) #undef OPENACC_WORKER_MODIFIER }; -#endif +enum OpenACCAsyncModifier { + ACCC_ASYNC_unspecified, + ACCC_ASYNC_expr +}; +enum OpenACCGangArgKind { + ACCC_GANG_ARG_unknown, + ACCC_GANG_ARG_num, + ACCC_GANG_ARG_dim, + ACCC_GANG_ARG_static, + ACCC_GANG_ARG_other +}; + +#endif diff --git a/src/acclexer.g4 b/src/acclexer.g4 index ea49d7a..66527f1 100644 --- a/src/acclexer.g4 +++ b/src/acclexer.g4 @@ -159,6 +159,11 @@ LINE_END : [\n\r] -> skip ; +CLAUSE_COMMA + : ',' [\p{White_Space}]* + -> skip + ; + ASYNC : 'async' [\p{White_Space}]* { @@ -1299,4 +1304,3 @@ EXPRESSION_CHAR } } ; - diff --git a/tests/test_single_pragma.sh b/tests/test_single_pragma.sh index b64285d..8327da5 100755 --- a/tests/test_single_pragma.sh +++ b/tests/test_single_pragma.sh @@ -44,10 +44,34 @@ while IFS= read -r pragma || [ -n "$pragma" ]; do continue fi - # Try to parse the pragma - if ! "$ROUNDTRIP_BIN" "$pragma" >/dev/null 2>&1; then + # Try to parse and unparse the pragma + if ! roundtrip_output="$("$ROUNDTRIP_BIN" "$pragma" 2>/dev/null)"; then echo "Parse failed at line $line_num: $pragma" failed=$((failed + 1)) + continue + fi + + # Normalise whitespace for comparison so formatting differences don't mask + # structural mismatches. + normalize() { + echo "$1" | tr '\n' ' ' \ + | sed -e 's/[[:space:]]\+/ /g' \ + -e 's/ *: */:/g' \ + -e 's/, */,/g' \ + -e 's/ ,/,/g' \ + -e 's/) *, */) /g' \ + -e 's/ *( */(/g' \ + -e 's/^ //;s/ $//' + } + + norm_input=$(normalize "$pragma") + norm_output=$(normalize "$roundtrip_output") + + if [ "$norm_input" != "$norm_output" ]; then + echo "Roundtrip mismatch at line $line_num:" + echo " input : $norm_input" + echo " output: $norm_output" + failed=$((failed + 1)) fi done < "$PRAGMA_FILE"