Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/PTO/IR/PTOAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,20 @@ def PTO_RoundModeEnum : PTO_I32Enum<
def PTO_RoundModeAttr : EnumAttr<PTO_Dialect, PTO_RoundModeEnum, "round_mode"> {
let summary = "rounding mode attribute";
}

//===----------------------------------------------------------------------===//
// SaturationMode
//===----------------------------------------------------------------------===//

def PTO_SaturationModeEnum : PTO_I32Enum<
"SaturationMode", "PTO saturation mode", [
I32EnumAttrCase<"ON", 0>,
I32EnumAttrCase<"OFF", 1>
]>;

def PTO_SaturationModeAttr : EnumAttr<PTO_Dialect, PTO_SaturationModeEnum, "saturation_mode"> {
let summary = "saturation mode attribute";
}
//===----------------------------------------------------------------------===//
// TStore template controls
//===----------------------------------------------------------------------===//
Expand Down
46 changes: 38 additions & 8 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def TMatmulMxAccOp : PTO_TOp<"tmatmul.mx.acc", [
}];

let extraClassDeclaration = [{
static StringRef getIntrinsicName() { return "TMATMUL_MX_ACC"; }
static StringRef getIntrinsicName() { return "TMATMUL_MX"; }
::mlir::pto::PIPE getPipe() {
return ::mlir::pto::PIPE::PIPE_M;
}
Expand Down Expand Up @@ -807,7 +807,7 @@ def TMatmulMxBiasOp : PTO_TOp<"tmatmul.mx.bias",[
}];

let extraClassDeclaration = [{
static StringRef getIntrinsicName() { return "TMATMUL_MX_BIAS"; }
static StringRef getIntrinsicName() { return "TMATMUL_MX"; }
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
Expand Down Expand Up @@ -3159,28 +3159,58 @@ def TCvtOp : PTO_TOp<"tcvt", [
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Elementwise type conversion with rounding mode (tilebuf, DPS)";
let summary = "Elementwise type conversion with optional tmp tile and saturation mode (tilebuf, DPS)";

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst,
DefaultValuedAttr<PTO_RoundModeAttr, "::mlir::pto::RoundMode::CAST_RINT">:$rmode
Optional<PTODpsType>:$tmp,
DefaultValuedAttr<PTO_RoundModeAttr, "::mlir::pto::RoundMode::CAST_RINT">:$rmode,
OptionalAttr<PTO_SaturationModeAttr>:$sat_mode
);

let results = (outs);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TRandomOp : PTO_TOp<"trandom", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Generate random values into dst tile using key/counter words.";

let arguments = (ins
AnySignlessInteger:$key0,
AnySignlessInteger:$key1,
AnySignlessInteger:$counter0,
AnySignlessInteger:$counter1,
AnySignlessInteger:$counter2,
AnySignlessInteger:$counter3,
PTODpsType:$dst,
DefaultValuedAttr<I32Attr, "10">:$rounds
);

let results = (outs);
let hasVerifier = 1;

let assemblyFormat = [{
`ins` `(` $src
`ins` `(` $key0 `,` $key1 `,` $counter0 `,` $counter1 `,` $counter2 `,` $counter3
attr-dict
`:` qualified(type($src)) `)`
`:` type($key0) `,` type($key1) `,` type($counter0) `,` type($counter1) `,` type($counter2) `,` type($counter3) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TDivOp : PTO_TOp<"tdiv", [
Expand Down Expand Up @@ -3955,6 +3985,7 @@ def TMrgSortOp: PTO_TOp<"tmrgsort", [
Variadic<PTODpsType>:$srcs,
Optional<AnyInteger>:$blockLen,
Variadic<PTODpsType>:$dsts,
Optional<PTODpsType>:$tmp,
Optional<AnyType>:$excuted,
DefaultValuedAttr<BoolAttr, "false">:$exhausted
);
Expand All @@ -3965,10 +3996,9 @@ def TMrgSortOp: PTO_TOp<"tmrgsort", [

let extraClassDeclaration = [{
bool isFormat1() { return getSrcs().size() == 1u && getBlockLen() && getDsts().size() == 1u; }
bool isFormat2() { return getSrcs().size() >= 2u && getSrcs().size() <= 4u && getDsts().size() == 2u && getExcuted(); }
bool isFormat2() { return getSrcs().size() >= 2u && getSrcs().size() <= 4u && getTmp() && getDsts().size() == 1u && getExcuted(); }
Value getSrc() { return getSrcs().front(); }
Value getDst() { return getDsts().front(); }
Value getTmp() { return getDsts()[1]; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstsMutable(); }
void print(::mlir::OpAsmPrinter &p);
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
Expand Down
3 changes: 3 additions & 0 deletions include/pto-c/Dialect/PTO.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOReluPreModeAttrGetValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOSaturationModeAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsASaturationModeAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED int32_t mlirPTOSaturationModeAttrGetValue(MlirAttribute attr);
// ---- Pipe attr ----
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOPipeAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAPipeAttr(MlirAttribute attr);
Expand Down
31 changes: 31 additions & 0 deletions lib/Bindings/Python/PTOModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ static void bindPTOModule(pybind11::module &m) {
.value("ODD", mlir::pto::RoundMode::ODD)
.value("CAST_RINT", mlir::pto::RoundMode::CAST_RINT);

py::enum_<mlir::pto::SaturationMode>(m, "SaturationMode")
.value("ON", mlir::pto::SaturationMode::ON)
.value("OFF", mlir::pto::SaturationMode::OFF);

py::enum_<MlirPTOCmpMode>(m, "CmpMode")
.value("EQ", MlirPTOCmpMode_EQ)
.value("NE", MlirPTOCmpMode_NE)
Expand Down Expand Up @@ -315,6 +319,33 @@ static void bindPTOModule(pybind11::module &m) {
return mlirPTORoundModeAttrGetValue(self);
});

mlir_attribute_subclass(
m, "SaturationModeAttr",
[](MlirAttribute a) { return mlirPTOAttrIsASaturationModeAttr(a); })
.def_classmethod(
"get",
[](py::object cls, py::object value, MlirContext ctx) -> py::object {
int32_t v = 0;
if (py::isinstance<py::int_>(value)) {
v = value.cast<int32_t>();
} else if (py::hasattr(value, "value")) {
v = value.attr("value").cast<int32_t>();
} else {
throw std::runtime_error("SaturationModeAttr.get expects int or SaturationMode enum");
}

MlirAttribute a = mlirPTOSaturationModeAttrGet(ctx, v);
if (mlirAttributeIsNull(a)) return py::none();
return cls.attr("__call__")(a);
},
py::arg("cls"), py::arg("value"), py::arg("context") = py::none())

.def_property_readonly(
"value",
[](MlirAttribute self) -> int32_t {
return mlirPTOSaturationModeAttrGetValue(self);
});

mlir_attribute_subclass(
m, "PipeAttr",
[](MlirAttribute a) { return mlirPTOAttrIsAPipeAttr(a); })
Expand Down
15 changes: 15 additions & 0 deletions lib/CAPI/Dialect/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr) {
return static_cast<int32_t>(a.getValue());
}

MlirAttribute mlirPTOSaturationModeAttrGet(MlirContext ctx, int32_t value) {
auto *c = unwrap(ctx);
auto mode = static_cast<mlir::pto::SaturationMode>(value);
return wrap(mlir::pto::SaturationModeAttr::get(c, mode));
}

bool mlirPTOAttrIsASaturationModeAttr(MlirAttribute attr) {
return mlir::isa<mlir::pto::SaturationModeAttr>(unwrap(attr));
}

int32_t mlirPTOSaturationModeAttrGetValue(MlirAttribute attr) {
auto a = mlir::cast<mlir::pto::SaturationModeAttr>(unwrap(attr));
return static_cast<int32_t>(a.getValue());
}

MlirAttribute mlirPTOPipeAttrGet(MlirContext ctx, int32_t value) {
auto *c = unwrap(ctx);
auto v = static_cast<mlir::pto::PIPE>(value);
Expand Down
Loading
Loading