Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions arbor/include/arbor/load_balance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ using partition_hint_map = std::unordered_map<cell_kind, partition_hint>;
/// Construct a domain_decomposition that distributes the cells in the model
/// described by `rec` over the distributed and local hardware resources
/// described by `ctx`, assigning gids to domains in consecutive blocks.
///
/// @param rec      recipe describing the model.
/// @param ctx      execution context describing the available resources.
/// @param hint_map optional partition hints per cell kind; empty by default.
/// @return the resulting domain decomposition.
ARB_ARBOR_API domain_decomposition_ptr partition_load_balance(const recipe& rec,
context ctx,
const partition_hint_map& hint_map = {});

/// Construct a domain_decomposition that distributes the cells in the model
/// described by `rec` over the distributed and local hardware resources
/// described by `ctx` by round-robin assignment, i.e. on N ranks, rank i is
/// handed the gids i, i + N, i + 2N, ...
///
/// @param rec      recipe describing the model.
/// @param ctx      execution context describing the available resources.
/// @param hint_map optional partition hints per cell kind; empty by default.
/// @return the resulting domain decomposition.
ARB_ARBOR_API domain_decomposition_ptr round_robin_load_balance(const recipe& rec,
context ctx,
const partition_hint_map& hint_map = {});

} // namespace arb
128 changes: 81 additions & 47 deletions arbor/partition_load_balance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "cell_group_factory.hpp"
#include "execution_context.hpp"
#include "util/maputil.hpp"
#include "util/partition.hpp"
#include "util/span.hpp"
#include "util/strprintf.hpp"

Expand All @@ -21,9 +20,17 @@ namespace arb {
namespace {
using gj_connection_set = std::unordered_set<cell_gid_type>;
using gj_connection_table = std::unordered_map<cell_gid_type, gj_connection_set>;
using gid_range = std::pair<cell_gid_type, cell_gid_type>;
using super_cell = std::vector<cell_gid_type>;

// Half-open, strided set of gids: beg, beg + dlt, beg + 2*dlt, ... as long as
// the value stays below end. With the defaults the range is empty and has
// unit stride.
struct gid_range {
    cell_gid_type beg{0}; // first gid of the range
    cell_gid_type end{0}; // exclusive upper bound
    cell_gid_type dlt{1}; // stride between consecutive gids
};

// A (stepped) range contains exactly the gids beg, beg + dlt, beg + 2*dlt, ...
// that are smaller than end, i.e. all gids with beg <= gid < end whose offset
// from beg is a multiple of the step.
//
// NOTE: the stride must be measured relative to `beg`, not to zero; otherwise
// a round-robin range starting at a non-zero gid (rank i on N ranks owns
// i, i + N, ...) would reject its own gids and accept those of another rank.
bool contains_gid(const gid_range& gids, cell_gid_type gid) {
    // outside of the half-open interval [beg, end)
    if (gid < gids.beg || gid >= gids.end) return false;
    // gid >= beg holds here, so the subtraction cannot wrap around
    return (gid - gids.beg) % gids.dlt == 0;
}

// Build global GJ connectivity table such that
// * table[gid] is the set of all gids connected to gid via a GJ
// * iff A in table[B], then B in table[A]
Expand Down Expand Up @@ -55,66 +62,73 @@ auto make_local_gid_range(context ctx, cell_gid_type num_global_cells) {
// all previous domains, incl ours, have an extra element
auto beg = domain_id*(block + 1);
auto end = beg + block + 1;
return std::make_pair(beg, end);
return gid_range { .beg=beg, .end=end, .dlt=1 };
}
else {
// in this case the first `extra` domains added an extra element and the
// rest has size `block`
auto beg = extra + domain_id*block;
auto end = beg + block;
return std::make_pair(beg, end);
return gid_range { .beg=beg, .end=end, .dlt=1 };
}
}

// Round-robin assignment of gids: on N ranks, rank i is handed the strided
// range i, i + N, i + 2N, ... bounded above by num_global_cells.
auto make_round_robin_gids(context ctx, cell_gid_type num_global_cells) {
    const auto& distributed = ctx->distributed;
    const unsigned n_ranks = distributed->size();
    const unsigned rank = distributed->id();

    gid_range gids;
    gids.beg = rank;             // start at our own rank id
    gids.end = num_global_cells; // never step past the last global gid
    gids.dlt = n_ranks;          // skip over the gids of all other ranks
    return gids;
}

// build the list of components for the local domain, where a component is a list of
// cell gids such that
// * the smallest gid in the list is in the local_gid_range
// * all gids that are connected to the smallest gid are also in the list
// * all gids w/o GJ connections come first (for historical reasons!?)
auto build_components(const gj_connection_table& global_gj_connection_table,
gid_range local_gid_range) {
// cells connected by gj
std::vector<super_cell> super_cells;
const gid_range& gids) {
// singular cells
std::vector<super_cell> res;
for (cell_gid_type gid = gids.beg; gid < gids.end; gid += gids.dlt) {
if (!global_gj_connection_table.count(gid)) {
res.push_back({gid});
}
}

// cells connected by gj
// track visited cells (cells that already belong to a group)
gj_connection_set visited;
// Connected components via BFS
std::vector<cell_gid_type> q;
for (auto gid: util::make_span(local_gid_range)) {
if (global_gj_connection_table.count(gid)) {
// If cell hasn't been visited yet, must belong to new component
if (visited.insert(gid).second) {
// pivot gid: the smallest found in this group; must be at
// smaller or equal to `gid`.
auto min_gid = gid;
q.push_back(gid);
super_cell sc;
while (!q.empty()) {
auto element = q.back();
q.pop_back();
sc.push_back(element);
min_gid = std::min(element, min_gid);
// queue up conjoined cells
for (const auto& peer: global_gj_connection_table.at(element)) {
if (visited.insert(peer).second) q.push_back(peer);
}
}
// if the pivot gid belongs to our domain, this group will be part
// of our domain, keep it and sort.
if (min_gid >= local_gid_range.first) {
std::sort(sc.begin(), sc.end());
super_cells.emplace_back(std::move(sc));
}
for (cell_gid_type gid = gids.beg; gid < gids.end; gid += gids.dlt) {
// not in a GJ compound cell, skip
if (!global_gj_connection_table.count(gid)) continue;
// cell has been visited, skip
if (!visited.insert(gid).second) continue;
// pivot gid: the smallest gid found in this group; it is always smaller
// than or equal to `gid`. We use this to determine whether this
// component is ours
auto min_gid = gid;
q.push_back(gid);
super_cell sc;
while (!q.empty()) {
auto element = q.back();
q.pop_back();
sc.push_back(element);
min_gid = std::min(element, min_gid);
// queue up conjoined cells
for (const auto& peer: global_gj_connection_table.at(element)) {
if (visited.insert(peer).second) q.push_back(peer);
}
}
else {
res.push_back({gid});
// if the pivot gid belongs to our domain, this group will be part
// of our domain: sort and add to result
if (contains_gid(gids, min_gid)) {
std::sort(sc.begin(), sc.end());
res.emplace_back(std::move(sc));
}
}
// append super cells to result
res.reserve(res.size() + super_cells.size());
std::move(super_cells.begin(), super_cells.end(), std::back_inserter(res));
return res;
}

Expand Down Expand Up @@ -163,18 +177,24 @@ auto build_group_parameters(context ctx,

// Build the list of GJ-connected cells local to this domain.
// NOTE We put this into its own function to avoid increasing RSS.
auto build_local_components(const recipe& rec, context ctx) {
auto build_local_components_by_range(const recipe& rec, context ctx) {
const auto global_gj_connection_table = build_global_gj_connection_table(rec);
const auto local_gid_range = make_local_gid_range(ctx, rec.num_cells());
return build_components(global_gj_connection_table, local_gid_range);
const auto local_gids = make_local_gid_range(ctx, rec.num_cells());
return build_components(global_gj_connection_table, local_gids);
}

} // namespace
auto build_local_components_by_round_robin(const recipe& rec, context ctx) {
const auto global_gj_connection_table = build_global_gj_connection_table(rec);
const auto local_gids = make_round_robin_gids(ctx, rec.num_cells());
return build_components(global_gj_connection_table, local_gids);
}

ARB_ARBOR_API domain_decomposition_ptr partition_load_balance(const recipe& rec,
context ctx,
const partition_hint_map& hint_map) {
const auto components = build_local_components(rec, ctx);
template<typename F>
domain_decomposition_ptr do_load_balance(const recipe& rec,
context ctx,
const partition_hint_map& hint_map,
F&& component_builder) {
const auto components = component_builder(rec, ctx);

std::vector<cell_gid_type> local_gids;
std::unordered_map<cell_kind, std::vector<cell_gid_type>> kind_lists;
Expand All @@ -196,8 +216,8 @@ ARB_ARBOR_API domain_decomposition_ptr partition_load_balance(const recipe& rec,
for (const auto& params: kinds) {
std::vector<cell_gid_type> group_elements;
// group_elements are sorted such that the gids of all members of a component are consecutive.
for (auto cell: kind_lists[params.kind]) {
const auto& component = components[cell];
for (auto cell_idx: kind_lists[params.kind]) {
const auto& component = components[cell_idx];
// adding the current group would go beyond alloted size, so add to the list
// of groups and start a new one.
if (group_elements.size() + component.size() > params.size && !group_elements.empty()) {
Expand All @@ -213,5 +233,19 @@ ARB_ARBOR_API domain_decomposition_ptr partition_load_balance(const recipe& rec,
}

return std::make_shared<domain_decomposition>(rec, ctx, groups);
}
} // namespace

// Load balancer entry point that builds the local components with
// build_local_components_by_range and hands them to the shared balancing
// routine do_load_balance.
ARB_ARBOR_API domain_decomposition_ptr partition_load_balance(const recipe& rec,
                                                              context ctx,
                                                              const partition_hint_map& hint_map) {
    auto decomposition = do_load_balance(rec, ctx, hint_map, build_local_components_by_range);
    return decomposition;
}

// Load balancer entry point that builds the local components with
// build_local_components_by_round_robin and hands them to the shared
// balancing routine do_load_balance.
ARB_ARBOR_API domain_decomposition_ptr round_robin_load_balance(const recipe& rec,
                                                                context ctx,
                                                                const partition_hint_map& hint_map) {
    auto decomposition = do_load_balance(rec, ctx, hint_map, build_local_components_by_round_robin);
    return decomposition;
}

} // namespace arb
11 changes: 5 additions & 6 deletions arbor/profile/profiler.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <cstdio>
#include <mutex>
#include <ostream>
#include <sstream>
#include <utility>

#include <arbor/context.hpp>
Expand Down Expand Up @@ -152,13 +153,11 @@ const accumulators_type& recorder::accumulators() const {
}

std::string timer_stack_to_string(const timer_stack& ts, const std::vector<std::string>& names) {
std::stringstream ss{};
for (auto i=0U;i<ts.size();i++) {
const auto timer = ts[i];
std::stringstream ss;
for (auto ix = 0U; ix < ts.size(); ++ix) {
const auto timer = ts[ix];
ss << names[timer];
if (i != ts.size()-1) {
ss << ", ";
}
if (ix != ts.size() - 1) ss << ", ";
}
return ss.str();
}
Expand Down
44 changes: 30 additions & 14 deletions python/domain_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,36 @@ void register_domain_decomposition(pybind11::module& m) {
// Partition load balancer
// The Python recipe has to be shimmed for passing to the function that takes a C++ recipe.
m.def("partition_load_balance",
[](std::shared_ptr<recipe>& recipe, const context_shim& ctx, arb::partition_hint_map hint_map) {
try {
return arb::partition_load_balance(recipe_shim(recipe), ctx.context, std::move(hint_map));
}
catch (...) {
py_reset_and_throw();
throw;
}
},
"Construct a domain_decomposition that distributes the cells in the model described by recipe\n"
"over the distributed and local hardware resources described by context.\n"
"Optionally, provide a dictionary of partition hints for certain cell kinds, by default empty.",
"recipe"_a, "context"_a, "hints"_a=arb::partition_hint_map{});

[](std::shared_ptr<recipe>& recipe, const context_shim& ctx, arb::partition_hint_map hint_map) {
try {
return arb::partition_load_balance(recipe_shim(recipe), ctx.context, std::move(hint_map));
}
catch (...) {
py_reset_and_throw();
throw;
}
},
"Construct a domain_decomposition that distributes the cells in the model described by recipe\n"
"over the distributed and local hardware resources described by context in consecutive blocks.\n"
"Optionally, provide a dictionary of partition hints for certain cell kinds, by default empty.",
"recipe"_a, "context"_a, "hints"_a=arb::partition_hint_map{});

// Round-robin load balancer
m.def("round_robin_load_balance",
[](std::shared_ptr<recipe>& recipe, const context_shim& ctx, arb::partition_hint_map hint_map) {
try {
return arb::round_robin_load_balance(recipe_shim(recipe), ctx.context, std::move(hint_map));
}
catch (...) {
py_reset_and_throw();
throw;
}
},
"Construct a domain_decomposition that distributes the cells in the model described by recipe\n"
"over the distributed and local hardware resources described by context by round-robin assignment.\n"
"Optionally, provide a dictionary of partition hints for certain cell kinds, by default empty.",
"recipe"_a, "context"_a, "hints"_a=arb::partition_hint_map{});

m.def("partition_by_group",
[](std::shared_ptr<recipe>& recipe, const context_shim& ctx, const std::vector<arb::group_description>& groups) {
try {
Expand Down
Loading
Loading