Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/fusilli/attributes/pointwise_attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace fusilli {
OP(GELU_APPROX_TANH_FWD) \
/* OP(GELU_BWD) */ \
OP(GELU_FWD) \
/* OP(GEN_INDEX) */ \
OP(GEN_INDEX) \
OP(IDENTITY) \
OP(LOG) \
OP(LOGICAL_AND) \
Expand Down Expand Up @@ -116,6 +116,11 @@ class PointwiseAttr : public AttributesCRTP<PointwiseAttr> {
return *this;
}

// Selects which output dimension GEN_INDEX fills with per-element
// coordinates (consumed by the emitter's linalg.index op).
PointwiseAttr &setGenIdxAxis(int64_t idxAxis) {
  genIdxAxis_ = idxAxis;
  return *this;
}

// Getters:
FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_0)
FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, IN_1)
Expand All @@ -126,6 +131,7 @@ class PointwiseAttr : public AttributesCRTP<PointwiseAttr> {
float getEluAlpha() const { return eluAlpha_; }
float getSoftplusBeta() const { return softplusBeta_; }
float getSoftplusThreshold() const { return softplusThreshold_; }
int64_t getGenIdxAxis() const { return genIdxAxis_; }

// Utilities for pointwise modes.
static const std::unordered_map<Mode, std::string> kModeToStr;
Expand All @@ -137,6 +143,7 @@ class PointwiseAttr : public AttributesCRTP<PointwiseAttr> {
float eluAlpha_ = 1.0f;
float softplusBeta_ = 1.0f;
float softplusThreshold_ = 20.0f;
int64_t genIdxAxis_ = 0;
};

#define FUSILLI_DECLARE_STRINGIFY_POINTWISE_MODE(mode) \
Expand Down Expand Up @@ -165,6 +172,7 @@ inline const std::unordered_map<PointwiseAttr::Mode, int>
{PointwiseAttr::Mode::FLOOR, 1},
{PointwiseAttr::Mode::GELU_APPROX_TANH_FWD, 1},
{PointwiseAttr::Mode::GELU_FWD, 1},
{PointwiseAttr::Mode::GEN_INDEX, 1},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit strange that it takes an input that it never uses. Is this just for shape inference?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct — IN_0 is only used for shape inference. The Graph::pointwise single-operand overload (graph.h:937) feeds the input's shape through inferPropertiesNode so the OUT_0 dims/strides can be derived without having to set them explicitly. This matches cuDNN's Graph API shape for GEN_INDEX (it takes an input tensor whose shape drives the output shape). The emitted linalg.generic doesn't consume IN_0's values.

{PointwiseAttr::Mode::LOG, 1},
{PointwiseAttr::Mode::LOGICAL_AND, 2},
{PointwiseAttr::Mode::LOGICAL_NOT, 1},
Expand Down
76 changes: 59 additions & 17 deletions include/fusilli/attributes/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,40 @@

namespace fusilli {

// Category of a DataType. Boolean is treated as integer since it is stored as
// i1 and participates in integer arithmetic / index_cast.
enum class DataTypeCategory : uint8_t {
  Float,   // Floating-point types (f16, bf16, f32, f64, f8E5M2).
  Integer, // Signed/unsigned integer types, plus Boolean (i1).
};

// Define a macro to iterate over all fusilli datatypes and the corresponding
// torch datatypes and mlir asm.
// torch datatypes, mlir asm, and category.
#define FUSILLI_FORALL_DATA_TYPES(_) \
_(Half, Half, "f16") \
_(BFloat16, BFloat16, "bf16") \
_(Float, Float, "f32") \
_(Double, Double, "f64") \
_(Uint8, Byte, "ui8") \
_(Int4, Undefined, "si4") \
_(Int8, Char, "si8") \
_(Int16, Short, "si16") \
_(Int32, Int, "si32") \
_(Int64, Long, "si64") \
_(Boolean, Bool, "i1") \
_(FP8E5M2, Float8_e5m2, "f8E5M2")
_(Half, Half, "f16", Float) \
_(BFloat16, BFloat16, "bf16", Float) \
_(Float, Float, "f32", Float) \
_(Double, Double, "f64", Float) \
_(Uint8, Byte, "ui8", Integer) \
_(Int4, Undefined, "si4", Integer) \
_(Int8, Char, "si8", Integer) \
_(Int16, Short, "si16", Integer) \
_(Int32, Int, "si32", Integer) \
_(Int64, Long, "si64", Integer) \
_(Boolean, Bool, "i1", Integer) \
_(FP8E5M2, Float8_e5m2, "f8E5M2", Float)

enum class DataType : uint8_t {
NotSet,
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) FUSILLI_TYPE,
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) FUSILLI_TYPE,
FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
};

// Map from Fusilli types to MLIR types.
static const std::unordered_map<DataType, std::string> kDataTypeToMlirTypeAsm =
{
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \
{DataType::FUSILLI_TYPE, MLIR_TYPE},
FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
Expand All @@ -57,7 +64,7 @@ static const std::unordered_map<DataType, std::string> kDataTypeToMlirTypeAsm =
// Map from Fusilli types to Torch types.
static const std::unordered_map<DataType, torch_upstream::ScalarType>
kDataTypeToTorchType = {
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \
{DataType::FUSILLI_TYPE, torch_upstream::ScalarType::TORCH_TYPE},
FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
Expand All @@ -66,12 +73,47 @@ static const std::unordered_map<DataType, torch_upstream::ScalarType>
// Map from MLIR type ASM strings to Fusilli types.
static const std::unordered_map<std::string, DataType> kMlirTypeAsmToDataType =
{
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE) \
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \
{MLIR_TYPE, DataType::FUSILLI_TYPE},
FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
};

// Map from Fusilli types to their category (float vs integer).
static const std::unordered_map<DataType, DataTypeCategory> kDataTypeCategory =
{
#define DEFINE_ENUM(FUSILLI_TYPE, TORCH_TYPE, MLIR_TYPE, CATEGORY) \
{DataType::FUSILLI_TYPE, DataTypeCategory::CATEGORY},
FUSILLI_FORALL_DATA_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
};

// Returns true iff `dtype` is a floating-point type (e.g. f16, bf16, f32).
// Unknown/unset dtypes (not present in kDataTypeCategory) return false.
inline bool isFloatDataType(DataType dtype) {
  const auto entry = kDataTypeCategory.find(dtype);
  if (entry == kDataTypeCategory.end())
    return false;
  return entry->second == DataTypeCategory::Float;
}

// Returns true iff `dtype` is an integer type (including Boolean, which is
// stored as i1). Unknown/unset dtypes return false.
inline bool isIntegerDataType(DataType dtype) {
  const auto entry = kDataTypeCategory.find(dtype);
  if (entry == kDataTypeCategory.end())
    return false;
  return entry->second == DataTypeCategory::Integer;
}

// Returns the signless MLIR element type for a DataType, stripping the
// leading "s"/"u" that kDataTypeToMlirTypeAsm carries on integer types
// ("si32" -> "i32", "ui8" -> "i8"). Builtin MLIR tensors and arith ops use
// signless integers, unlike torch vtensors which preserve signedness.
inline std::string getSignlessElementTypeAsm(DataType dtype) {
  std::string elemType = kDataTypeToMlirTypeAsm.at(dtype);
  // compare() + replace() mutate in place; the previous substr-based form
  // allocated two temporary strings per prefix test and a third for the
  // concatenation.
  if (elemType.size() >= 2 && (elemType.compare(0, 2, "si") == 0 ||
                               elemType.compare(0, 2, "ui") == 0))
    elemType.replace(0, 2, "i");
  return elemType;
}

} // namespace fusilli

#endif // FUSILLI_ATTRIBUTES_TYPES_H
22 changes: 22 additions & 0 deletions include/fusilli/node/pointwise_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,28 @@ class PointwiseNode : public NodeCRTP<PointwiseNode> {
return ok();
}

// Post-shape-inference validation. Only GEN_INDEX carries extra
// constraints: the axis must index into the (now known) output rank and
// the output dtype must be one the emitter can cast an index to.
ErrorObject postValidateNode() const override final {
  FUSILLI_LOG_LABEL_ENDL("INFO: Post-Validating PointwiseNode '"
                         << pointwiseAttr.getName() << "'");

  // Every other pointwise mode has nothing to check here.
  if (pointwiseAttr.getMode() != PointwiseAttr::Mode::GEN_INDEX)
    return ok();

  const auto &outTensor = pointwiseAttr.getOUT_0();
  const int64_t genIdxAxis = pointwiseAttr.getGenIdxAxis();
  const int64_t rank = static_cast<int64_t>(outTensor->getDim().size());
  FUSILLI_RETURN_ERROR_IF(
      genIdxAxis < 0 || genIdxAxis >= rank, ErrorCode::InvalidAttribute,
      "GEN_INDEX axis " + std::to_string(genIdxAxis) +
          " is out of range for output of rank " + std::to_string(rank));

  const DataType dtype = outTensor->getDataType();
  FUSILLI_RETURN_ERROR_IF(
      !isFloatDataType(dtype) && !isIntegerDataType(dtype),
      ErrorCode::InvalidAttribute,
      "GEN_INDEX only supports integer or float output dtypes");

  return ok();
}

ErrorObject inferPropertiesNode() override final {
FUSILLI_LOG_LABEL_ENDL("INFO: Inferring properties for PointwiseNode '"
<< pointwiseAttr.getName() << "'");
Expand Down
118 changes: 118 additions & 0 deletions include/fusilli/support/asm_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include <cstdint>
#include <format> // C++20
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -128,6 +129,47 @@ inline std::string buildTensorTypeStr(const std::vector<int64_t> &dims,
return oss.str();
}

// Builds a builtin MLIR tensor type string (e.g. "tensor<16x256xf32>") from
// explicit dims and dtype. Uses the signless element type so the result
// bridges cleanly with torch vtensors via torch_c.from_builtin_tensor.
inline std::string buildBuiltinTensorTypeStr(const std::vector<int64_t> &dims,
                                             DataType dtype) {
  std::string result = "tensor<";
  for (const int64_t dim : dims) {
    result += std::to_string(dim);
    result += "x";
  }
  result += getSignlessElementTypeAsm(dtype);
  result += ">";
  return result;
}

// Builds an identity affine_map over `rank` dimensions, e.g.
//   rank=3 -> "affine_map<(d0, d1, d2) -> (d0, d1, d2)>"
// Used for linalg.generic indexing_maps where every operand/result accesses
// the full iteration space in natural order. rank=0 yields
// "affine_map<() -> ()>".
inline std::string getIdentityAffineMapAsm(size_t rank) {
  // A plain indexed loop replaces the previous iota-vector + interleave
  // approach, which allocated an O(rank) scratch vector just to count.
  std::string dims;
  for (size_t i = 0; i < rank; ++i) {
    if (i != 0)
      dims += ", ";
    dims += "d" + std::to_string(i);
  }
  return "affine_map<(" + dims + ") -> (" + dims + ")>";
}

// Builds a comma-separated list of `rank` linalg iterator type attrs all of
// the given kind (e.g. "parallel" or "reduction"), e.g.
//   rank=3, kind="parallel" -> "\"parallel\", \"parallel\", \"parallel\""
// Paired with an indexing_maps attribute on a linalg.generic. rank=0 yields
// an empty string.
inline std::string getIteratorTypesAsm(size_t rank, std::string_view kind) {
  // A plain counted loop replaces the previous iota-vector + interleave
  // approach, which allocated an O(rank) index vector it never read.
  std::string result;
  for (size_t i = 0; i < rank; ++i) {
    if (i != 0)
      result += ", ";
    result += '"';
    result += kind;
    result += '"';
  }
  return result;
}

// ---------------------------------------------------------------------------
// Torch IR constant helpers
//
Expand Down Expand Up @@ -1894,6 +1936,82 @@ inline std::string PointwiseNode::emitNodePreAsm() const {
FUSILLI_DECLARE_SUB_ADD_TORCH_EMITTER(ADD, torch.aten.add.Tensor)
FUSILLI_DECLARE_SUB_ADD_TORCH_EMITTER(SUB, torch.aten.sub.Tensor)

case PointwiseAttr::Mode::GEN_INDEX: {
  const auto &out0 = pointwiseAttr.getOUT_0();
  const std::string suffix = getName();
  const std::vector<int64_t> &outDims = out0->getDim();
  const int64_t axis = pointwiseAttr.getGenIdxAxis();
  // Range and dtype correctness are enforced in postValidateNode; these
  // asserts are safety nets in case emitNodePreAsm is called without a
  // prior graph->validate().
  assert(axis >= 0 && axis < static_cast<int64_t>(outDims.size()) &&
         "GEN_INDEX axis out of range");

  const size_t rank = outDims.size();
  const DataType outDtype = out0->getDataType();
  const bool isFloat = isFloatDataType(outDtype);
  assert((isFloat || isIntegerDataType(outDtype)) &&
         "GEN_INDEX only supports integer or float output dtypes");
  const std::string signlessDtype = getSignlessElementTypeAsm(outDtype);

  // Reuse outDtype rather than re-querying the tensor so the builtin tensor
  // type and the scalar cast below can never disagree on the dtype.
  const std::string builtinTensorType =
      buildBuiltinTensorTypeStr(outDims, outDtype);
  const std::string vtensorType = out0->getTensorTypeAsm(
      /*isValueTensor=*/true, /*useLogicalDims=*/true);

  const std::string indexingMap = getIdentityAffineMapAsm(rank);
  const std::string iterators = getIteratorTypesAsm(rank, "parallel");

  // One full schema per cast flavour: integer outputs take a single
  // index_cast, floats go through i64 and then sitofp. Keeping the
  // linalg.generic and the cast together lets each variant be read as a
  // cohesive block of IR. (Two separate std::format calls are kept because
  // std::format requires a compile-time format string; selecting the schema
  // at runtime would force std::vformat.)
  constexpr std::string_view kGenIndexIntSchema = R"(
%gen_index_empty_{0} = tensor.empty() : {1}
%gen_index_linalg_{0} = linalg.generic {{indexing_maps = [{2}], iterator_types = [{3}]}} outs(%gen_index_empty_{0} : {1}) {{
^bb0(%gen_index_out_{0}: {4}):
%gen_index_idx_{0} = linalg.index {5} : index
%gen_index_val_{0} = arith.index_cast %gen_index_idx_{0} : index to {4}
linalg.yield %gen_index_val_{0} : {4}
}} -> {1}
{6} = torch_c.from_builtin_tensor %gen_index_linalg_{0} : {1} -> {7}
{8})";

  constexpr std::string_view kGenIndexFloatSchema = R"(
%gen_index_empty_{0} = tensor.empty() : {1}
%gen_index_linalg_{0} = linalg.generic {{indexing_maps = [{2}], iterator_types = [{3}]}} outs(%gen_index_empty_{0} : {1}) {{
^bb0(%gen_index_out_{0}: {4}):
%gen_index_idx_{0} = linalg.index {5} : index
%gen_index_int_{0} = arith.index_cast %gen_index_idx_{0} : index to i64
%gen_index_val_{0} = arith.sitofp %gen_index_int_{0} : i64 to {4}
linalg.yield %gen_index_val_{0} : {4}
}} -> {1}
{6} = torch_c.from_builtin_tensor %gen_index_linalg_{0} : {1} -> {7}
{8})";

  if (isFloat)
    return std::format(kGenIndexFloatSchema, suffix, /* {0} */
                       builtinTensorType,           /* {1} */
                       indexingMap,                 /* {2} */
                       iterators,                   /* {3} */
                       signlessDtype,               /* {4} */
                       axis,                        /* {5} */
                       getResultNamesAsm(),         /* {6} */
                       vtensorType,                 /* {7} */
                       permuteOUT0                  /* {8} */
    );
  return std::format(kGenIndexIntSchema, suffix, /* {0} */
                     builtinTensorType,         /* {1} */
                     indexingMap,               /* {2} */
                     iterators,                 /* {3} */
                     signlessDtype,             /* {4} */
                     axis,                      /* {5} */
                     getResultNamesAsm(),       /* {6} */
                     vtensorType,               /* {7} */
                     permuteOUT0                /* {8} */
  );
}

default:
assert(false && "Unsupported pointwise mode");
return "";
Expand Down
65 changes: 65 additions & 0 deletions samples/pointwise/pointwise_unary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,68 @@ TEST_CASE("Pointwise unary ops", "[pointwise][graph]") {
if (supportsFloat(mode) && supportsNegative(mode))
execute(handle, DataType::Half, half(-3.14));
}

// GEN_INDEX is a unary pointwise op but the expected value varies per output
// position (not a single scalar), so it doesn't fit the test harness above.
// Keep it in this file so all unary pointwise ops are exercised together.
TEST_CASE("Pointwise gen_index", "[pointwise][graph]") {
  const auto dim = std::vector<int64_t>{2, 3, 4, 5};
  // Exercise every axis of the rank-4 shape; Catch2 re-runs the test case
  // once per generated value.
  const int64_t axis = GENERATE(int64_t(0), int64_t(1), int64_t(2), int64_t(3));

  // Builds, validates, and compiles a single-node GEN_INDEX graph for the
  // given IO/compute dtype; returns (graph, input tensor, output tensor).
  auto buildNewGraph = [&](Handle &handle, DataType dt) {
    auto graph = std::make_shared<Graph>();
    graph->setName(std::format("pointwise_gen_index_dt{}_axis{}",
                               kDataTypeToMlirTypeAsm.at(dt), axis));
    graph->setIODataType(dt).setComputeDataType(dt);

    // Contiguous-strided input; its values are never read by GEN_INDEX —
    // the shape alone drives inference of the output dims.
    auto xT = graph->tensor(TensorAttr().setName("in0").setDim(dim).setStride(
        generateStrideFromDim(dim, getContiguousStrideOrder(dim.size()))));

    auto pointwiseAttr = PointwiseAttr()
                             .setMode(PointwiseAttr::Mode::GEN_INDEX)
                             .setGenIdxAxis(axis);
    auto pointwiseResult = graph->pointwise(xT, pointwiseAttr);

    pointwiseResult->setName("result").setOutput(true);

    FUSILLI_REQUIRE_OK(graph->validate());
    FUSILLI_REQUIRE_OK(graph->compile(handle, /*remove=*/true));

    return std::make_tuple(graph, xT, pointwiseResult);
  };

  // Executes the compiled graph and checks every output element equals its
  // coordinate along `axis`. T is the host-side element type matching dt.
  auto execute = [&]<typename T>(Handle &handle, DataType dt) {
    auto [graph, xT, yT] = buildNewGraph(handle, dt);

    // Input buffer is zero-filled; its contents are irrelevant (see above).
    FUSILLI_REQUIRE_ASSIGN(auto xBuf,
                           allocateBufferOfType(handle, xT, dt, T(0)));
    FUSILLI_REQUIRE_ASSIGN(auto yBuf,
                           allocateBufferOfType(handle, yT, dt, T(0)));

    const std::unordered_map<std::shared_ptr<TensorAttr>,
                             std::shared_ptr<Buffer>>
        variantPack = {{xT, xBuf}, {yT, yBuf}};

    FUSILLI_REQUIRE_ASSIGN(
        auto workspace, allocateWorkspace(handle, graph->getWorkspaceSize()));
    FUSILLI_REQUIRE_OK(graph->execute(handle, variantPack, workspace));

    std::vector<T> result;
    FUSILLI_REQUIRE_OK(yBuf->read(handle, result));

    // Expected value at flat index `i` is the coordinate along `axis`.
    // Recompute row-major (contiguous) strides so flat indices can be
    // decoded back into per-dimension coordinates.
    const size_t rank = dim.size();
    std::vector<int64_t> strides(rank, 1);
    for (int64_t k = static_cast<int64_t>(rank) - 2; k >= 0; --k)
      strides[k] = strides[k + 1] * dim[k + 1];

    for (size_t i = 0; i < result.size(); ++i) {
      int64_t coord = (static_cast<int64_t>(i) / strides[axis]) % dim[axis];
      REQUIRE(result[i] == static_cast<T>(coord));
    }
  };

  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  // One float and one integer dtype cover both emitter cast paths
  // (sitofp vs index_cast).
  execute.template operator()<float>(handle, DataType::Float);
  execute.template operator()<int32_t>(handle, DataType::Int32);
}
Loading
Loading