-
Notifications
You must be signed in to change notification settings - Fork 54
[AIROCMLIR-445] Lower migraphx.backwards_data_convolution
#2256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cf71020
d1736a7
756c35f
b401931
8b26f6e
6928e1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -272,6 +272,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite( | |
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Forward and Backward convolution converter | ||
| //===----------------------------------------------------------------------===// | ||
| namespace { | ||
| struct ConvConverter final | ||
| : public OpConversionPattern<migraphx::ConvolutionOp> { | ||
|
|
@@ -289,6 +292,24 @@ struct ConvConverter final | |
| migraphx::ConvolutionOp op, Value input, | ||
| Value filter) const; | ||
| }; | ||
|
|
||
| struct BackwardConvConverter final | ||
| : public OpConversionPattern<migraphx::ConvolutionBwdDataOp> { | ||
| using OpConversionPattern< | ||
| migraphx::ConvolutionBwdDataOp>::OpConversionPattern; | ||
| using OpConversionPattern<migraphx::ConvolutionBwdDataOp>::getTypeConverter; | ||
| using OpAdaptor = | ||
| typename OpConversionPattern<migraphx::ConvolutionBwdDataOp>::OpAdaptor; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override; | ||
|
|
||
| private: | ||
| LogicalResult emitBackwardConv(ConversionPatternRewriter &rewriter, | ||
| migraphx::ConvolutionBwdDataOp op, Value input, | ||
| Value filter) const; | ||
| }; | ||
| } // namespace | ||
|
|
||
| // Nice helper function for the linalg.generic op region | ||
|
|
@@ -302,19 +323,20 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) { | |
| } | ||
|
|
||
| /// Emit convolution attributes on the newly created operation. | ||
| static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp, | ||
| Attribute strides, Attribute dilation, | ||
| Attribute pad, Attribute convOpName) { | ||
| static void emitConvAttributes(Value convOp, Attribute strides, | ||
| Attribute dilation, Attribute pad, | ||
| Attribute perfConfig, Attribute groupAttr, | ||
| Attribute convOpName) { | ||
| Operation *newOp = convOp.getDefiningOp(); | ||
| newOp->setAttr("pad", pad); | ||
| newOp->setAttr("group", op.getGroupAttr()); | ||
| newOp->setAttr("group", groupAttr); | ||
| newOp->setAttr("stride", strides); | ||
| newOp->setAttr("dilation", dilation); | ||
|
|
||
| // Convert optional attributes | ||
| if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config")) | ||
| newOp->setAttr("perf_config", attr); | ||
| newOp->setAttr("conv_op", convOpName); | ||
| if (perfConfig) | ||
| newOp->setAttr("perf_config", perfConfig); | ||
| newOp->setAttr(rock::linalgConvOpAttrName, convOpName); | ||
| } | ||
|
|
||
| /// Emit a grouped convolution of any spatial rank (1D, 2D, or 3D). | ||
|
|
@@ -403,6 +425,113 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc, | |
| .getResult(0); | ||
| } | ||
|
|
||
| /// Emit a grouped backward (transposed) convolution of any spatial rank. | ||
| /// Input shape: (batch, group, channel, spatial...), | ||
| /// filter shape: (group, filter, channel, kernel_spatial...) | ||
| /// | ||
| /// The loop structure mirrors the forward convolution, but with the | ||
| /// stride/dilation affine expression on the *output* indexing map: | ||
| /// | ||
| /// clang-format off | ||
| /// for n in batch: | ||
| /// for g in group: | ||
| /// for ih_0 in input_spatial_0: | ||
| /// for ih_1 in input_spatial_1: | ||
| /// // ... | ||
| /// for ih_{dim-1} in input_spatial_{dim-1}: | ||
| /// for f in filters: | ||
| /// reduction starts here | ||
| /// for c in channels: // reduction | ||
| /// for kh_0 in kernel_spatial_0: // reduction | ||
| /// for kh_1 in kernel_spatial_1: // reduction | ||
| /// // ... | ||
| /// result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] += | ||
| /// input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...] | ||
| /// clang-format on | ||
| static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter, | ||
| Location loc, RankedTensorType resultType, | ||
| Value input, Value filter, Value zero, | ||
| ArrayAttr strides, ArrayAttr dilation) { | ||
| MLIRContext *ctx = rewriter.getContext(); | ||
| int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - 3; | ||
| SmallVector<int64_t, 4> strideVals; | ||
| SmallVector<int64_t, 4> dilationVals; | ||
| llvm::transform( | ||
| strides.getValue(), std::back_inserter(strideVals), | ||
| [](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); }); | ||
| llvm::transform( | ||
| dilation.getValue(), std::back_inserter(dilationVals), | ||
| [](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); }); | ||
|
|
||
| // Iteration domain layout (mirrors emitGroupedConv): | ||
| // parallel: batch, group, ih_0 .. ih_{dim-1}, filter | ||
| // reduction: channel, kh_0 .. kh_{dim-1} | ||
| // See the loop structure from above to see where these constants come from | ||
| const int64_t ihStart = 2; | ||
| const int64_t filterIdx = ihStart + spatialDim; | ||
| const int64_t channelIdx = filterIdx + 1; | ||
| const int64_t khStart = channelIdx + 1; | ||
| const int64_t totalDims = khStart + spatialDim; | ||
| const int64_t numParallel = channelIdx; | ||
|
|
||
| SmallVector<AffineExpr> d; | ||
| for (int64_t i = 0; i < totalDims; ++i) | ||
| d.push_back(getAffineDimExpr(i, ctx)); | ||
|
|
||
| AffineExpr batch = d[0], group = d[1]; | ||
| AffineExpr outChannel = d[filterIdx]; | ||
| AffineExpr inChannel = d[channelIdx]; | ||
|
|
||
| SmallVector<AffineExpr> inputExprs = {batch, group, inChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| inputExprs.push_back(d[ihStart + i]); | ||
|
|
||
| SmallVector<AffineExpr> filterExprs = {group, inChannel, outChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| filterExprs.push_back(d[khStart + i]); | ||
|
|
||
| SmallVector<AffineExpr> outputExprs = {batch, group, outChannel}; | ||
| for (int64_t i = 0; i < spatialDim; ++i) { | ||
| AffineExpr ih_i = d[ihStart + i]; | ||
| AffineExpr kh_i = d[khStart + i]; | ||
| outputExprs.push_back(ih_i * strideVals[i] + kh_i * dilationVals[i]); | ||
| } | ||
|
|
||
| SmallVector<AffineMap> indexingMaps = { | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx), | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx), | ||
| AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)}; | ||
|
|
||
| SmallVector<utils::IteratorType> iteratorTypes(numParallel, | ||
| utils::IteratorType::parallel); | ||
| iteratorTypes.append(totalDims - numParallel, utils::IteratorType::reduction); | ||
|
|
||
| auto result = linalg::GenericOp::create( | ||
| rewriter, loc, resultType, ValueRange{input, filter}, zero, | ||
| indexingMaps, iteratorTypes, convBodyBuilder) | ||
| .getResult(0); | ||
| return result; | ||
| } | ||
|
|
||
| /// Given the collapsed NF* result type and the group count, return the | ||
| /// expanded NGK* result type for the grouped linalg convolution. | ||
| static RankedTensorType expandResultForGroupedConv(RankedTensorType resultType, | ||
| int64_t group) { | ||
| ArrayRef<int64_t> resultShape = resultType.getShape(); | ||
| int64_t n = resultType.getDimSize(0); | ||
| int64_t newF = resultType.getDimSize(1) / group; | ||
| assert(resultType.getDimSize(1) % group == 0 && | ||
| "output channel must be divisible by group"); | ||
|
|
||
| SmallVector<int64_t, 4> newShape; | ||
| newShape.push_back(n); | ||
| newShape.push_back(group); | ||
| newShape.push_back(newF); | ||
| newShape.insert(newShape.end(), std::next(resultShape.begin(), 2), | ||
| resultShape.end()); | ||
| return RankedTensorType::get(newShape, resultType.getElementType()); | ||
| } | ||
|
|
||
| LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, | ||
| migraphx::ConvolutionOp op, Value input, | ||
| Value filter) const { | ||
|
|
@@ -444,7 +573,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter, | |
| Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter, | ||
| zero, strides, dilation); | ||
|
|
||
| emitConvAttributes(op, result, strides, dilation, op.getPaddingAttr(), | ||
| emitConvAttributes(result, strides, dilation, op.getPaddingAttr(), | ||
| op->getAttr("perf_config"), op.getGroupAttr(), | ||
| resultConvOpName); | ||
|
|
||
| // we must reshape the operand to what the type converter expects | ||
|
|
@@ -608,6 +738,116 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor, | |
| } | ||
|
|
||
| // TODO: migraphx::DeQuantizeLinearConverter | ||
| LogicalResult | ||
| BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter, | ||
| migraphx::ConvolutionBwdDataOp op, | ||
| Value input, Value filter) const { | ||
| Location loc = op.getLoc(); | ||
| int64_t group = op.getGroupAttr().getInt(); | ||
| int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - | ||
| 3; // exclude batch (N), group (G), channel (C) | ||
| if (spatialDim > 3) | ||
| return op.emitError("only support 1D to 3D conv_bwd"); | ||
|
|
||
| // To get the result shape, we must first add the padding | ||
| ArrayRef<Attribute> padding = op.getPaddingAttr().getValue(); | ||
| RankedTensorType originalResult = | ||
| cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult())); | ||
| SmallVector<int64_t, 4> resultShape(originalResult.getShape()); | ||
| SmallVector<int64_t, 4> lowPads; | ||
| SmallVector<int64_t, 4> highPads; | ||
| for (int64_t i = 0; i < spatialDim; ++i) { | ||
| int64_t lowPad = cast<IntegerAttr>(padding[i]).getInt(); | ||
| int64_t highPad = cast<IntegerAttr>(padding[i + spatialDim]).getInt(); | ||
| // The first two dimension of the result is batch and channel, and we apply | ||
| // padding to the spatial dimension | ||
| resultShape[2 + i] += lowPad + highPad; | ||
| lowPads.push_back(lowPad); | ||
| highPads.push_back(highPad); | ||
| } | ||
| RankedTensorType resultType = | ||
| RankedTensorType::get(resultShape, originalResult.getElementType()); | ||
| auto newResultType = expandResultForGroupedConv(resultType, group); | ||
| Value zero = arith::ConstantOp::create(rewriter, loc, newResultType, | ||
| rewriter.getZeroAttr(newResultType)); | ||
|
|
||
| ArrayAttr strides = op.getStride(); | ||
| ArrayAttr dilation = op.getDilation(); | ||
|
|
||
| Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input, | ||
| filter, zero, strides, dilation); | ||
| rock::LinalgConvType convType = | ||
| (spatialDim == 3) ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd | ||
| : (spatialDim == 2) ? rock::LinalgConvType::Conv2dBWDNgchwGckhw | ||
| : rock::LinalgConvType::Conv1dBWDNgchGckh; | ||
| emitConvAttributes( | ||
| result, strides, dilation, op.getPaddingAttr(), | ||
| op->getAttr("perf_config"), op.getGroupAttr(), | ||
| rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType)); | ||
|
|
||
| // Collapse result from NGK* back to NK* | ||
| SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}}; | ||
| llvm::for_each(llvm::seq<int64_t>(3, spatialDim + 3), | ||
| [&](int64_t index) { reassociation.push_back({index}); }); | ||
| auto finalResult = | ||
| tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation) | ||
| .getResult(); | ||
|
|
||
| bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) || | ||
| llvm::any_of(highPads, [](int64_t p) { return p != 0; }); | ||
| if (hasPadding) { | ||
| int64_t rank = originalResult.getRank(); | ||
| SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); | ||
| SmallVector<OpFoldResult> sizes; | ||
| SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); | ||
| for (int64_t i = 0; i < rank; ++i) | ||
| sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i))); | ||
| for (int64_t i = 0; i < spatialDim; ++i) | ||
| offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I remember that conv_bwd can probably have negative padding values. Can you check whether this would work for that case?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this would work if we have negative padding. What are the semantics of negative padding? Is it the same as applying an output padding whose magnitude is those negative values? |
||
| finalResult = | ||
| tensor::ExtractSliceOp::create(rewriter, loc, originalResult, | ||
| finalResult, offsets, sizes, strides) | ||
| .getResult(); | ||
| } | ||
|
|
||
| rewriter.replaceOp(op, finalResult); | ||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult BackwardConvConverter::matchAndRewrite( | ||
| migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| // Backward convolution lowering is similar to forward convolution and is | ||
| // lowered in three steps: | ||
| // 1. Expand the channel dimension into (group, channel_per_group), | ||
| // introducing | ||
| // a group dimension G. Input becomes NGC* (e.g. NGCL, NGCHW, NGCDHW) and | ||
| // filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW), matching the group attr. | ||
| // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the | ||
| // result back to the original NFHW/NFDHW shape for the type converter. | ||
| Location loc = op.getLoc(); | ||
| Value input = adaptor.getInput(); | ||
| Value filter = adaptor.getFilter(); | ||
| RankedTensorType inputType = cast<RankedTensorType>(input.getType()); | ||
| int64_t dim = inputType.getRank() - 2; | ||
| int64_t group = op.getGroupAttr().getInt(); | ||
|
|
||
| if (dim > 3 || dim < 1) { | ||
| return op.emitError(Twine(dim) + "D conv is not supported for now"); | ||
| } | ||
|
|
||
| if (inputType.getElementType() != op.getFilter().getType().getElementType() || | ||
| inputType.getElementType() != op.getResult().getType().getElementType()) { | ||
| return op.emitError( | ||
| "type casting between operands and result is unsupported for now"); | ||
| } | ||
|
|
||
| input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim); | ||
| filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim); | ||
|
|
||
| return emitBackwardConv(rewriter, op, input, filter); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Base kernels (gemm) | ||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -1583,8 +1823,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns( | |
| LiteralConverter, ReshapeConverter, | ||
| BooleanElementwiseConverter<migraphx::Greater>, | ||
| BooleanElementwiseConverter<migraphx::Equal>, ClipConverter, | ||
| TransposeConverter, ConvConverter, SliceConverter>( | ||
| converter, patterns.getContext()); | ||
| TransposeConverter, ConvConverter, SliceConverter, | ||
| BackwardConvConverter>(converter, patterns.getContext()); | ||
| } | ||
|
|
||
| void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| // RUN: rocmlir-opt -split-input-file --migraphx-to-linalg --canonicalize --cse --remove-dead-values %s | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: func.func @mlir_bwd_data_conv( | ||
| // CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}}) | ||
| // CHECK-DAG: %[[cst:.*]] = arith.constant | ||
| // CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg0]] | ||
| // CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg1]] | ||
| // CHECK-DAG: %[[conv:.*]] = linalg.generic {{.*}} ins(%[[expanded]], %[[expanded_0]] : tensor{{.*}}) outs(%[[cst]] : tensor{{.*}}) | ||
| // CHECK-SAME: attrs = {conv_op = #rock<LinalgConvType convbwd2d_ngchw_gckhw>, dilation = [1, 1], group = 1 : i64, pad = [1, 1, 1, 1], stride = [2, 3]} | ||
| // CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[conv]] | ||
| // CHECK-DAG: %[[extracted_slice:.*]] = tensor.extract_slice %[[collapsed]] | ||
| // CHECK-DAG: %[[collapsed_1:.*]] = tensor.collapse_shape %[[extracted_slice]] | ||
| // CHECK-DAG: return %[[collapsed_1]] | ||
| func.func @mlir_bwd_data_conv( | ||
| %arg0: !migraphx.shaped<1x3x6x7xf32, 126x42x7x1>, | ||
| %arg1: !migraphx.shaped<3x4x3x3xf32, 36x9x3x1> | ||
| ) -> !migraphx.shaped<1x4x11x19xf32, 836x209x19x1> { | ||
| %0 = migraphx.backwards_data_convolution %arg0, %arg1 { | ||
| dilation = [1, 1], | ||
| group = 1 : i64, | ||
| padding = [1, 1, 1, 1], | ||
| padding_mode = 0 : i64, | ||
| stride = [2, 3]} : <1x3x6x7xf32, 126x42x7x1>, <3x4x3x3xf32, 36x9x3x1> -> <1x4x11x19xf32, 836x209x19x1> | ||
| return %0 : !migraphx.shaped<1x4x11x19xf32, 836x209x19x1> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Output grad: NCDHW = 1x1x1x3x3, Filter: CKDHW = 1x1x1x3x3 | ||
| // stride=[1,1,1], dilation=[1,1,1], padding=[0,0,0,0,0,0], group=1 | ||
| // CHECK-LABEL: func.func @mlir_bwd_data_conv( | ||
| // CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}}) | ||
| // CHECK-DAG: %[[cst:.*]] = arith.constant | ||
| // CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg1]] | ||
| // CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg0]] | ||
| // CHECK-DAG: %[[conv:.*]] = linalg.generic {{.*}} ins(%[[expanded]], %[[expanded_0]] : tensor{{.*}}) outs(%[[cst]] : tensor{{.*}}) | ||
| // CHECK-SAME: attrs = {conv_op = #rock<LinalgConvType convbwd3d_ngchwd_gckhwd>, dilation = [1, 1, 1], group = 1 : i64, pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} | ||
| // CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[conv]] | ||
| // CHECK-DAG: return %[[collapsed]] | ||
| func.func @mlir_bwd_data_conv( | ||
| %arg0: !migraphx.shaped<1x1x1x3x3xf32, 9x9x9x3x1>, | ||
| %arg1: !migraphx.shaped<1x1x1x3x3xf32, 9x9x9x3x1> | ||
| ) -> !migraphx.shaped<1x1x1x5x5xf32, 25x25x25x5x1> attributes {rock.arch = "##TOKEN_ARCH##", rock.kernel} { | ||
| %0 = migraphx.backwards_data_convolution %arg1, %arg0 { | ||
| dilation = [1, 1, 1], | ||
| group = 1 : i64, | ||
| padding = [0, 0, 0, 0, 0, 0], | ||
| padding_mode = 0 : i64, | ||
| stride = [1, 1, 1] | ||
| } : <1x1x1x3x3xf32, 9x9x9x3x1>, <1x1x1x3x3xf32, 9x9x9x3x1> -> <1x1x1x5x5xf32, 25x25x25x5x1> | ||
| return %0 : !migraphx.shaped<1x1x1x5x5xf32, 25x25x25x5x1> | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.