Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class ReachingDefinitionsAnalysis
friend class ::mlir::DataFlowSolver;
ReachingDefinitionsAnalysis(
DataFlowSolver &solver,
llvm::function_ref<bool(Operation *)> definitionFilter,
llvm::function_ref<bool(Operation *)> definitionFilter = {},
llvm::function_ref<LogicalResult(InstOpInterface, KillDefsFn)>
killCallback = {})
: Base(solver), definitionFilter(definitionFilter),
Expand Down
14 changes: 14 additions & 0 deletions include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def ValueSchedulerAttr : AMDGCN_Attr<"ValueScheduler", "value_scheduler", [
}];
}

def RegisterSchedulerAttr : AMDGCN_Attr<"RegisterScheduler", "register_scheduler", [
DeclareAttrInterfaceMethods<SchedGraphAttrInterface>
]> {
let summary = "Register-aware scheduling graph builder for AMDGCN instructions";
let description = [{
Like the value scheduler, builds SSA edges, wait and barrier edges, and i1
serialization. Additionally, after non-SSA edges, adds dependencies from
reaching definitions for register operands on `InstOpInterface` ins and outs.

Requires IR in post-ToRegisterSemantics DPS normal form (no value-semantic
`outs` on instructions), the same precondition as `ReachingDefinitionsAnalysis`.
}];
}

def InstPropLabelerAttr : AMDGCN_Attr<"InstPropLabeler", "inst_prop_labeler", [
DeclareAttrInterfaceMethods<SchedLabelerAttrInterface>
]> {
Expand Down
1 change: 0 additions & 1 deletion include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,6 @@ def AMDGCN_TestInstOp : AMDGCN_Op<"test_inst", [

/// Get the opcode of the instruction.
InstAttr getOpcodeAttr() {
assert(false && "not yet implemented");
return InstAttr();
}
/// Get the instruction output operands.
Expand Down
2 changes: 2 additions & 0 deletions include/aster/Dialect/AMDGCN/Transforms/AMDGCNPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ def LowLevelScheduler : Pass<"amdgcn-low-level-scheduler",
let options = [
Option<"debugStalls", "debug-stalls", "bool", "false",
"Annotate each op with sched.stall_cycles and sched.stall_reason">,
Option<"registerSemantics", "register-semantics", "bool", "false",
"Use register semantics for scheduling dependencies">,
];
}

Expand Down
105 changes: 101 additions & 4 deletions lib/Dialect/AMDGCN/IR/SchedAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
//
//===----------------------------------------------------------------------===//

#include "aster/Dialect/AMDGCN/Analysis/ReachingDefinitions.h"
#include "aster/Dialect/AMDGCN/Analysis/WaitAnalysis.h"
#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h"
#include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h"
#include "aster/Dialect/LSIR/IR/LSIROps.h"
#include "aster/Interfaces/SchedInterfaces.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "llvm/ADT/STLExtras.h"
Expand All @@ -29,8 +32,9 @@ using namespace mlir::aster::amdgcn;

namespace {
struct GraphBuilder {
GraphBuilder(Block *block, const DataFlowSolver &solver)
: block(block), solver(const_cast<DataFlowSolver &>(solver)) {
GraphBuilder(Block *block, const DataFlowSolver &solver, bool useRegisterDeps)
: block(block), solver(const_cast<DataFlowSolver &>(solver)),
useRegisterDeps(useRegisterDeps) {
assert(block && "expected a valid block");
}

Expand All @@ -44,6 +48,9 @@ struct GraphBuilder {
/// Build the non-SSA dependencies for the graph.
void buildNonSSADeps(SchedGraph &graph);

/// Add edges from reaching definitions for DPS register ins/outs.
void buildRegisterDeps(SchedGraph &graph);

/// Handle a wait operation.
void handleWaitOp(SchedGraph &graph, int64_t pos, WaitOp wait);

Expand All @@ -57,6 +64,7 @@ struct GraphBuilder {
Block *block;
SmallVector<int64_t> syncPoints;
DataFlowSolver &solver;
bool useRegisterDeps;
};
} // namespace

Expand All @@ -71,6 +79,8 @@ ValueSchedulerAttr::initializeAnalyses(SchedAnalysis &analysis) const {
LogicalResult GraphBuilder::run(SchedGraph &graph) {
buildSSADeps(graph);
buildNonSSADeps(graph);
if (useRegisterDeps)
buildRegisterDeps(graph);
addI1SerializationEdges(graph);
return success();
}
Expand All @@ -86,9 +96,27 @@ void GraphBuilder::buildSSADeps(SchedGraph &graph) {
bool hasEffects = op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() ||
op->hasTrait<MemoryEffectOpInterface::Trait>();

bool isPureOp = mlir::isPure(op);

// If we're using register dependencies, then treat effects on register
// resources as non-effects.
if (useRegisterDeps && isa<InstOpInterface>(op) &&
!isa<lsir::SelectOp, lsir::CmpIOp>(op)) {
auto eOp = dyn_cast<MemoryEffectOpInterface>(op);
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
if (eOp) {
eOp.getEffects(effects);
isPureOp = llvm::all_of(effects, [](const SideEffects::EffectInstance<
MemoryEffects::Effect> &effect) {
return isa<VGPRResource, SGPRResource, AGPRResource, SREGResource>(
effect.getResource());
});
}
}

// If the operation has no side-effect we need to treat it as a possible
// sync point. Same for non-pure operations.
if ((!hasEffects || !mlir::isPure(op)) &&
if ((!hasEffects || !isPureOp) &&
!isa<LoadOp, StoreOp, AllocaOpInterface>(op)) {
LDBG() << "Adding sync point: " << i;
syncPoints.push_back(i);
Expand Down Expand Up @@ -144,6 +172,52 @@ void GraphBuilder::buildNonSSADeps(SchedGraph &graph) {
}
}

void GraphBuilder::buildRegisterDeps(SchedGraph &graph) {
// Helper function to add edges from reaching definitions for register
// operands.
auto addEdges = [&](Operation *op, int64_t opId, ValueRange values,
const ReachingDefinitionsState *beforeState) {
for (Value value : values) {
FailureOr<ValueRange> allocasOrFailure = getAllocasOrFailure(value);
assert(succeeded(allocasOrFailure) && "expected valid allocas");
for (Value alloc : *allocasOrFailure) {
for (const Definition &def : beforeState->getRange(alloc)) {
assert(def.definition && "expected valid definition");
Operation *producer = def.definition->getOwner();
int64_t pOpId = graph.getOpId(producer);
if (pOpId >= 0 && pOpId < opId)
graph.addEdge(producer, op);
}
}
}
};
for (auto [i, op] : llvm::enumerate(graph.getOps())) {
auto instOp = dyn_cast<InstOpInterface>(op);
// Skip non-InstOpInterface operations.
if (!instOp)
continue;

const auto *beforeState = solver.lookupState<ReachingDefinitionsState>(
solver.getProgramPointBefore(op));
assert(beforeState && "expected valid reaching definitions state");
ValueRange outs = instOp.getInstOuts();
ValueRange ins = instOp.getInstIns();
addEdges(op, i, ins, beforeState);
// Make sure we never clobber.
addEdges(op, i, outs, beforeState);
for (Operation *pOp : graph.getOps().take_front(i)) {
ValueRange prevVals = pOp->getOperands();
if (llvm::any_of(prevVals, [&](Value val) {
return llvm::is_contained(outs, val);
}))
graph.addEdge(pOp, op);
if (llvm::any_of(prevVals,
[&](Value val) { return llvm::is_contained(ins, val); }))
graph.addEdge(pOp, op);
}
}
}

void GraphBuilder::handleWaitOp(SchedGraph &graph, int64_t pos, WaitOp wait) {
// Get the wait state.
const WaitState *state =
Expand Down Expand Up @@ -305,7 +379,30 @@ FailureOr<SchedGraph>
ValueSchedulerAttr::createGraph(Block *block,
const SchedAnalysis &analysis) const {
SchedGraph graph(block);
GraphBuilder builder(block, analysis.getSolver());
GraphBuilder builder(block, analysis.getSolver(), /*useRegisterDeps=*/false);
if (failed(builder.run(graph)))
return failure();
graph.compress();
return graph;
}

//===----------------------------------------------------------------------===//
// RegisterSchedulerAttr - SchedGraphAttrInterface
//===----------------------------------------------------------------------===//

LogicalResult
RegisterSchedulerAttr::initializeAnalyses(SchedAnalysis &analysis) const {
analysis.getSolver().load<WaitAnalysis>(analysis.getDomInfo());
analysis.getSolver().load<ReachingDefinitionsAnalysis>();
analysis.setRunDataflowAnalyses();
return success();
}

FailureOr<SchedGraph>
RegisterSchedulerAttr::createGraph(Block *block,
const SchedAnalysis &analysis) const {
SchedGraph graph(block);
GraphBuilder builder(block, analysis.getSolver(), /*useRegisterDeps=*/true);
if (failed(builder.run(graph)))
return failure();
graph.compress();
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ struct LowLevelSchedulerPass
/*emitDiagnostics=*/true)))
return signalPassFailure();

SchedGraphAttrInterface builderAttr = ValueSchedulerAttr::get(ctx);
if (registerSemantics)
builderAttr = RegisterSchedulerAttr::get(ctx);
GenericSchedulerAttr compositeAttr = GenericSchedulerAttr::get(
ctx, ValueSchedulerAttr::get(ctx),
ctx, builderAttr,
SchedListLabelerAttr::get(ctx, ArrayRef<SchedLabelerAttrInterface>{}),
LowLevelSchedulerAttr::get(ctx, debugStalls));

Expand Down
8 changes: 2 additions & 6 deletions lib/Dialect/AMDGCN/Transforms/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,10 @@ static void buildRegAllocPassPipeline(OpPassManager &pm,
pm.addPass(createHoistIterArgWaits());
pm.addPass(createCanonicalizerPass());
}
if (options.llSched)
pm.addPass(createLowLevelScheduler());
pm.addPass(createAMDGCNBufferization());
if (options.hoistIterArgWaits) {
pm.addPass(createHoistIterArgWaits());
pm.addPass(createCanonicalizerPass());
}
pm.addPass(createToRegisterSemantics());
// if (options.llSched)
pm.addPass(createLowLevelScheduler({false, true}));
// Post-condition of to-register-semantics is now enforced by
// KernelOp::verifyRegions() via the normal_forms attribute set by the pass.
pm.addPass(createRegisterDCE());
Expand Down
Loading