Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
4efdd1a
Add channelwise conv
pfultz2 Feb 16, 2026
a0c6b07
Format
pfultz2 Feb 16, 2026
efeafca
Use shared memory
pfultz2 Feb 16, 2026
4498934
Format
pfultz2 Feb 16, 2026
1792edb
Update slice functions
pfultz2 Feb 16, 2026
0304972
Format
pfultz2 Feb 16, 2026
1389ae5
Update to use slices instead
pfultz2 Feb 16, 2026
9c9b9a5
Format
pfultz2 Feb 16, 2026
207e5d6
Add reduce_schedule for outer batches
pfultz2 Feb 16, 2026
cdae8f4
Format
pfultz2 Feb 16, 2026
b51b82f
Use pooling_reduce
pfultz2 Feb 16, 2026
b5f4f0f
Format
pfultz2 Feb 16, 2026
15fd39f
Some refactoring to use tiling
pfultz2 Feb 16, 2026
b61daa3
Format
pfultz2 Feb 16, 2026
c9d258f
Access directly
pfultz2 Feb 16, 2026
6d979f5
Format
pfultz2 Feb 16, 2026
ecbce52
Add join
pfultz2 Feb 16, 2026
4bd6556
Update tuning
pfultz2 Feb 16, 2026
d1da333
Format
pfultz2 Feb 16, 2026
9cc6906
Add multi-output
pfultz2 Feb 17, 2026
0942c87
Format
pfultz2 Feb 17, 2026
ca147d2
Add spatial tiler
pfultz2 Feb 17, 2026
3b17a09
Format
pfultz2 Feb 17, 2026
037d10f
Avoid bounds check when there is no padding
pfultz2 Feb 17, 2026
7bc6d78
Remove lines
pfultz2 Feb 17, 2026
e3077b8
Use functions instead of variables
pfultz2 Feb 17, 2026
414aab4
Format
pfultz2 Feb 17, 2026
e56c4f1
Inline methods
pfultz2 Feb 17, 2026
b51c74f
Format
pfultz2 Feb 17, 2026
3d4bfe4
Update quick tuning list
pfultz2 Feb 17, 2026
a362a19
Format
pfultz2 Feb 17, 2026
208c7ad
Add another config
pfultz2 Feb 18, 2026
f2daa29
Add more configs
pfultz2 Feb 18, 2026
36110cf
Format
pfultz2 Feb 18, 2026
882fe3b
Add pointwise fusion
pfultz2 Mar 2, 2026
24a2645
Format
pfultz2 Mar 2, 2026
28e32af
Only enable for float and navi
pfultz2 Mar 2, 2026
e35373c
Format
pfultz2 Mar 2, 2026
f69d9bb
Fix tidy
pfultz2 Mar 2, 2026
fb48be7
Format
pfultz2 Mar 2, 2026
ef923a8
Fix tidy
pfultz2 Mar 2, 2026
513fafc
Update year
pfultz2 Mar 2, 2026
ec3c657
Fix cppcheck
pfultz2 Mar 2, 2026
5d8051b
Format
pfultz2 Mar 2, 2026
99c896c
Use std algos
pfultz2 Mar 2, 2026
9f0903d
Format
pfultz2 Mar 2, 2026
680328b
Move in_bounds function
pfultz2 Mar 2, 2026
1120309
Rename type
pfultz2 Mar 2, 2026
7645792
Format
pfultz2 Mar 2, 2026
32b5894
Fix compilation failure
pfultz2 Mar 2, 2026
2141264
Format
pfultz2 Mar 2, 2026
19cf173
Simplify some more
pfultz2 Mar 2, 2026
b39416e
Format
pfultz2 Mar 2, 2026
6c990fd
Use std::transform
pfultz2 Mar 2, 2026
90638f8
Precompute slices
pfultz2 Mar 2, 2026
053bf4f
Format
pfultz2 Mar 2, 2026
ffaa5c3
Update src/targets/gpu/kernels/include/migraphx/kernels/slice.hpp
pfultz2 Mar 2, 2026
8a06baf
Change the navi check
pfultz2 Mar 2, 2026
a3fd388
Merge branch 'channelwise-conv2' of github.com:ROCmSoftwarePlatform/A…
pfultz2 Mar 2, 2026
258af41
Split verify classes
pfultz2 Mar 2, 2026
bcd468d
Revert the reduce and index changes
pfultz2 Mar 2, 2026
7ba2cca
Revert pooling changes
pfultz2 Mar 2, 2026
61f6ffb
Use signed integer
pfultz2 Mar 2, 2026
2a770dd
Merge branch 'develop' into channelwise-conv2
pfultz2 Mar 2, 2026
b5cad75
Update year
pfultz2 Mar 2, 2026
5b49459
Format
pfultz2 Mar 2, 2026
dc7f7e5
Fix merge conflicts
pfultz2 Mar 3, 2026
9eb50da
Merge branch 'develop' into channelwise-conv2
TedThemistokleous Mar 9, 2026
18a7efa
Support padding
pfultz2 Apr 3, 2026
c23a8e8
Format
pfultz2 Apr 3, 2026
747292c
Fix selection
pfultz2 Apr 3, 2026
ad9b8d1
Fix padding
pfultz2 Apr 4, 2026
77dac35
Cleanup
pfultz2 Apr 4, 2026
3c3e0ac
Merge
pfultz2 Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,37 @@ struct find_layernorm_pointwise
}
};

// Fuses a pointwise instruction into the gpu::channelwise_conv that feeds
// its first argument, so the pointwise computation runs as the kernel's
// epilogue instead of as a separate launch.
struct find_channelwise_conv_pointwise
{
    // Match a non-tuple precompiled pointwise whose arg(0) is a
    // precompiled gpu::channelwise_conv.
    auto matcher() const
    {
        return precompile_name("pointwise")(
            match::not_tuple(),
            match::arg(0)(precompile_name("gpu::channelwise_conv").bind("channelwise_conv")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto pointwise = r.result;
        auto conv      = r.instructions["channelwise_conv"];
        // A conv that already carries a module input has been fused before;
        // leave it alone.
        if(not conv->module_inputs().empty())
            return;
        auto* pw_mod = pointwise->module_inputs().front();

        // Collect the pointwise inputs that are NOT the conv result itself;
        // these become extra arguments of the fused kernel.
        auto extra_inputs = pointwise->inputs();
        auto conv_pos     = std::find(extra_inputs.begin(), extra_inputs.end(), conv);
        assert(conv_pos != extra_inputs.end());
        extra_inputs.erase(conv_pos);

        // Build the fused input list: the conv inputs minus its trailing
        // buffer (presumably the output allocation -- the replacement's
        // output buffer arrives via the appended pointwise inputs), followed
        // by the remaining pointwise inputs.
        auto fused_inputs = conv->inputs();
        fused_inputs.pop_back();
        fused_inputs.insert(fused_inputs.end(), extra_inputs.begin(), extra_inputs.end());

        // The fused op keeps the conv's attributes but produces the
        // pointwise output shape.
        auto fused_val            = conv->get_operator().to_value();
        fused_val["output_shape"] = to_value(pointwise->get_shape());

        m.replace_instruction(
            pointwise, make_op(conv->name(), fused_val), fused_inputs, {pw_mod});
    }
};

struct find_concat_pointwise
{
auto matcher() const
Expand Down Expand Up @@ -1069,6 +1100,7 @@ void fuse_ops::apply(module& m) const
#endif
match::find_matches(m,
find_layernorm_pointwise{},
find_channelwise_conv_pointwise{},
find_concat_pointwise{},
find_contiguous_transpose_rocblas_gemm{},
#if MIGRAPHX_USE_HIPBLASLT
Expand Down
5 changes: 4 additions & 1 deletion src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -34,8 +34,11 @@ struct module_pass_manager;

namespace gpu {

struct context;

struct MIGRAPHX_GPU_EXPORT prefuse_ops
{
context* ctx = nullptr;
bool enable_attention = false;
std::string name() const { return "gpu::prefuse_ops"; }
void apply(module_pass_manager& mpm) const;
Expand Down
199 changes: 199 additions & 0 deletions src/targets/gpu/jit/channelwise_conv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

// NOLINTNEXTLINE
static const char* const channelwise_conv_kernel = R"__migraphx__(
#include <migraphx/kernels/channelwise_conv.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {

MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto output, auto x, auto w, auto... inputs) {
channelwise_conv<index_ints<${tile}>, ${ntiles}>(index_ints<${tile}>{}, index_ints<${padding}>{}, ${post}, output, x, w, inputs...);
});
}

}

} // namespace migraphx

)__migraphx__";

// JIT compiler for the channelwise convolution kernel: lowers the op (plus
// an optionally fused pointwise module) to a HIP code object, and exposes a
// tuning space over tile sizes and outputs-per-lane.
struct channelwise_conv_compiler : compiler<channelwise_conv_compiler>
{
    // Handles both the precompiled gpu op and the plain op name.
    std::vector<std::string> names() const { return {"gpu::channelwise_conv", "channelwise_conv"}; }

    // Instantiate the kernel source template with the tile configuration
    // taken from `v` and compile it.
    //
    // `v` keys consumed here: num_spatial, tile_h, tile_w, noutputs,
    // padding, and optional kernel/post/preamble overrides (set by
    // compile() when a pointwise module is fused).
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        auto num_spatial = v.at("num_spatial").to<std::size_t>();
        // The last input shape is treated as the kernel output buffer.
        const auto& out_s = inputs.back();
        options.inputs = inputs;
        options.output = out_s;
        options.kernel_name = v.get("kernel", std::string{"channelwise_conv_kernel"});
        options.virtual_inputs = inputs;

        const auto& out_lens = out_s.lens();

        // Thread block tile dimensions
        std::vector<std::size_t> tile_sizes(num_spatial, 1);
        if(num_spatial == 1)
        {
            tile_sizes[0] = v.get("tile_w", std::size_t{256});
        }
        else
        {
            // Only the first and last spatial dims are tiled; any middle
            // dims (3D+ convs) keep a tile size of 1.
            tile_sizes[0] = v.get("tile_h", std::size_t{8});
            tile_sizes[num_spatial - 1] = v.get("tile_w", std::size_t{32});
        }

        // Outputs per lane along W (last spatial dim)
        auto noutputs = v.get("noutputs", std::size_t{4});

        // Output tile = lane tile with last dim scaled by noutputs
        std::vector<std::size_t> output_tile_sizes = tile_sizes;
        output_tile_sizes.back() *= noutputs;

        // One lane per lane-tile element.
        std::size_t block_size = std::accumulate(
            tile_sizes.begin(), tile_sizes.end(), std::size_t{1}, std::multiplies<>());

        // Blocks: N * C_out * prod(ceil(out_spatial / output_tile))
        // (out_lens[0]/out_lens[1] are batch and channels; the remaining
        // dims are spatial.)
        auto num_blocks = std::inner_product(
            out_lens.begin() + 2,
            out_lens.end(),
            output_tile_sizes.begin(),
            out_lens[0] * out_lens[1],
            std::multiplies<>{},
            [](auto out_spatial, auto tile) { return (out_spatial + tile - 1) / tile; });

        options.set_launch_params(v, num_blocks * block_size, block_size);

        // Guarantee one lo/hi padding entry pair per spatial dim; missing
        // entries default to zero (no padding).
        auto padding = v.get("padding", std::vector<std::size_t>{});
        if(padding.size() < 2 * num_spatial)
            padding.resize(2 * num_spatial, 0);

        // Substitute the tile configuration and the (optional) fused
        // pointwise epilogue into the kernel source template.
        auto src = interpolate_string(channelwise_conv_kernel,
                                      {{"tile", to_string_range(tile_sizes)},
                                       {"ntiles", std::to_string(noutputs)},
                                       {"padding", to_string_range(padding)},
                                       {"kernel", options.kernel_name},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"post", v.get("post", std::string{"op::id{}"})},
                                       {"preamble", v.get("preamble", std::string{})}});

        return compile_hip_code_object(ctx, src, options);
    }

    // Framework entry point: merge the chosen tuning solution into the op
    // value and, when the instruction carries a pointwise module, generate
    // that module's code as the kernel's post-op epilogue.
    compiler_replace
    compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
    {
        auto v = op.to_value();
        for(const auto& x : solution)
            v.insert(x);
        if(not ins->module_inputs().empty())
        {
            auto* pm = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_channelwise_conv");
            v["post"] = "MIGRAPHX_LIFT(post_channelwise_conv)";
            // Encode the fused ops in the kernel name for profiling/debug.
            v["kernel"] = "channelwise_conv_" + generate_name_from_ops(*pm) + "_kernel";
        }
        return compile_op(ctx, to_shapes(ins->inputs()), v);
    }

    // Tuning space: exhaustive mode sweeps tile_h x tile_w x noutputs;
    // otherwise a hand-picked quick list is returned.
    optional<tuning_config> get_tuning_config(const context& ctx,
                                              instruction_ref ins,
                                              const operation&,
                                              bool exhaustive) const
    {
        tuning_config tc;
        auto shapes = to_shapes(ins->inputs());
        tc.problem = to_value(shapes);
        if(exhaustive)
        {
            // Candidate tile sizes: multiples of 4 (presumably 4..252,
            // assuming a half-open range(1, 64) -- verify migraphx::range).
            std::vector<std::size_t> sizes;
            transform(range(1, 64), std::back_inserter(sizes), [](auto i) { return i * 4; });
            for(auto tile_h : sizes)
            {
                for(auto tile_w : sizes)
                {
                    // Keep only launchable block sizes: within the HW limit,
                    // at least one full wavefront, and wavefront-aligned.
                    auto block_size = tile_h * tile_w;
                    if(block_size > 1024)
                        continue;
                    if(block_size < ctx.get_current_device().get_wavefront_size())
                        continue;
                    if((block_size % ctx.get_current_device().get_wavefront_size()) != 0)
                        continue;
                    for(auto opt : {1, 2, 4, 8})
                        tc.solutions.push_back(
                            {{"tile_h", tile_h}, {"tile_w", tile_w}, {"noutputs", opt}});
                }
            }
        }
        else
        {
            // Quick-tuning shortlist of configurations.
            tc.solutions.push_back({{"tile_h", 8}, {"tile_w", 32}, {"noutputs", 1}});

            tc.solutions.push_back({{"tile_h", 8}, {"tile_w", 8}, {"noutputs", 8}});
            tc.solutions.push_back({{"tile_h", 8}, {"tile_w", 16}, {"noutputs", 2}});
            tc.solutions.push_back({{"tile_h", 8}, {"tile_w", 64}, {"noutputs", 4}});
            tc.solutions.push_back({{"tile_h", 8}, {"tile_w", 64}, {"noutputs", 8}});
            tc.solutions.push_back({{"tile_h", 16}, {"tile_w", 8}, {"noutputs", 4}});
            tc.solutions.push_back({{"tile_h", 16}, {"tile_w", 16}, {"noutputs", 2}});
            tc.solutions.push_back({{"tile_h", 16}, {"tile_w", 64}, {"noutputs", 4}});
            tc.solutions.push_back({{"tile_h", 32}, {"tile_w", 16}, {"noutputs", 8}});
            tc.solutions.push_back({{"tile_h", 32}, {"tile_w", 32}, {"noutputs", 1}});
            tc.solutions.push_back({{"tile_h", 40}, {"tile_w", 12}, {"noutputs", 1}});
            tc.solutions.push_back({{"tile_h", 48}, {"tile_w", 16}, {"noutputs", 1}});
            tc.solutions.push_back({{"tile_h", 56}, {"tile_w", 4}, {"noutputs", 1}});
            tc.solutions.push_back({{"tile_h", 76}, {"tile_w", 8}, {"noutputs", 8}});
            tc.solutions.push_back({{"tile_h", 128}, {"tile_w", 8}, {"noutputs", 8}});
        }
        return tc;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
12 changes: 12 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,18 @@ constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U,
return integral_const_array<T, f(Xs, Ys)...>{};
}

// Concatenate two compile-time integer sequences into a single
// integral_const_array; the element type of the first sequence wins.
template <class T, T... Xs, class U, U... Ys>
constexpr auto join(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>)
{
    return integral_const_array<T, Xs..., Ys...>{};
}

// Concatenate three or more sequences by folding pairwise joins
// left-to-right; the result is identical to direct pack expansion.
template <class T, T... Xs, class U, U... Ys, class... Arrays>
constexpr auto
join(integral_const_array<T, Xs...> x, integral_const_array<U, Ys...> y, Arrays... arrays)
{
    return join(join(x, y), arrays...);
}

template <class F>
constexpr auto return_array_c(F f)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CHANNELWISE_CONV_HPP
#define MIGRAPHX_GUARD_KERNELS_CHANNELWISE_CONV_HPP

#include <migraphx/kernels/spatial_tiler.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/copy.hpp>

namespace migraphx {

// Channelwise convolution device kernel: each output element is a
// multiply-accumulate over the weight window only (no reduction across
// channels is visible here).
//
// TileLens  - compile-time per-spatial-dim tile sizes
// NTiles    - passed through to the spatial tiler; from the host side this
//             is `noutputs` (outputs per lane along the last spatial dim) --
//             NOTE(review): inferred from the template instantiation in the
//             compiler; confirm against spatial_tiler.hpp
// Padding   - compile-time padding values
// f         - fused epilogue: op::id by default, or a generated pointwise
// output    - destination tensor
// x, w      - input activations and weights
// inputs... - extra tensors consumed by the fused epilogue
template <class TileLens,
          index_int NTiles,
          class Padding,
          class F,
          class Output,
          class Input,
          class Weights,
          class... Inputs>
__device__ void
channelwise_conv(TileLens, Padding, F f, Output output, Input x, Weights w, Inputs... inputs)
{
    auto idx = make_index();
    auto tiler = make_spatial_tiler<NTiles>(idx, TileLens{}, get_shape_c<Output>{}, Padding{});

    // Shared-memory staging buffer, sized by the tiler for one input tile.
    __shared__ decltype(tiler.template shared_allocate<Input>()) smem;

    // Stage this block's input tile in shared memory; the other tensors are
    // only sliced (viewed in place), not copied.
    auto x_ch = tiler.copy(x, smem);
    auto w_ch = tiler.slice(w);
    auto out_ch = tiler.slice(output);
    auto xs_pack = pack(tiler.slice(inputs)...);

    // Copy the weight window into a thread-local array (likely registers --
    // depends on the compiler) so the inner loop avoids global-memory reads.
    using type = typename Output::type;
    array<type, decltype(w_ch.get_shape().elements()){}> wregs_arr;
    auto wregs = make_tensor_view(wregs_arr.begin(), make_packed_shape(w_ch.get_shape()));
    copy(w_ch.begin(), w_ch.end(), wregs.begin());

    // Make the shared-memory tile visible to all lanes before any lane
    // reads elements written by its neighbours.
    __syncthreads();

    // For each output position owned by this lane: accumulate over the
    // weight window, then apply the fused epilogue with the extra inputs
    // sampled at the same output position.
    tiler.for_each([&](auto out_pos, auto out_multi) {
        type acc = 0;
        repeat(wregs.get_shape().elements(), [&](auto ki) {
            auto k_multi = wregs.get_shape().multi(ki);
            acc += x_ch[out_multi + k_multi] * wregs[k_multi];
        });
        xs_pack([&](auto... xs) { out_ch[out_pos] = f(acc, xs[out_pos]...); });
    });
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CHANNELWISE_CONV_HPP
Loading
Loading