diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp new file mode 100644 index 00000000000..f1a9716de4f --- /dev/null +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -0,0 +1,99 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct mean_variance_norm : op_parser +{ + std::set valid_types = { + shape::bf16_type, shape::double_type, shape::float_type, shape::half_type}; + + std::vector operators() const + { + return {{"MeanVarianceNormalization", "mean_variance_norm"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + const auto dtype = args[0]->get_shape().type(); + const auto literal_dtype = dtype; + + if(not contains(valid_types, dtype)) + { + MIGRAPHX_THROW(opd.onnx_name + ": invalid output type: " + std::to_string(dtype) + + ". Valid types are (bfloat16), (double), (float) and (half)"); + } + + const auto& x = args[0]; + + const auto eps_default = (dtype == shape::half_type) ? 1e-7f : 1e-9f; + const auto axes_default = std::vector{0, 2, 3}; + + auto eps = eps_default; + if (contains(info.attributes, "epsilon")) + { + eps = parser.parse_value(info.attributes.at("epsilon")).at(); + } + + auto axes = axes_default; + auto axes_min_size = axes.size(); + if (contains(info.attributes, "axes")) + { + axes.assign(info.attributes.at("axes").ints().begin(), info.attributes.at("axes").ints().end()); + axes_min_size = axes.size(); + } + + if (x->get_shape().ndim() < axes_min_size) + { + MIGRAPHX_THROW(opd.onnx_name + ": input dimension has value: " + std::to_string(x->get_shape().ndim()) + + ". It sould be greater or equal to: " + std::to_string(axes_min_size)); + } + + auto expected_val_x = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); + auto expected_val_sqr_x = info.add_common_op("mul", expected_val_x, expected_val_x); + auto x_sqr = info.add_common_op("mul", x, x); + auto expected_val_x_sqr = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqr); + auto std_sqr = info.add_common_op("sub", expected_val_x_sqr, expected_val_sqr_x); + auto std = info.add_common_op("sqrt", std_sqr); + auto numerator = info.add_common_op("sub", x, expected_val_x); + auto eps_literal = info.add_literal(literal{shape{literal_dtype}, {eps}}); + auto denominator = info.add_common_op("add", std, eps_literal); + auto y = info.add_common_op("div", numerator, denominator); + + return y; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 298607cac84..d3ab0a6bfe3 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15767,3 +15767,19 @@ def scan_arg_shapes_mismatch_test(): ) return ([node], [init_state, scan_ins1, scan_ins2], [final_state, scan_outs]) + +@onnx_test() +def mean_variance_norm_test(): + return mvn_n_rank_test_base([2, 3], [3, 3, 3, 1]) + +@onnx_test() +def mean_variance_norm_default_axes_test(): + return mvn_default_axes_test_base([3, 3, 3, 1]) + +@onnx_test() +def mean_variance_norm_invalid_type_test(): + return mvn_default_axes_test_base([3, 3, 3, 1], type=TensorProto.INT8) + +@onnx_test() +def mean_variance_norm_invalid_axes_test(): + return mvn_n_rank_test_base(axes=[2, 3, 1, 4], dims=[3, 1]) diff --git a/test/onnx/mean_variance_norm_default_axes_test.onnx b/test/onnx/mean_variance_norm_default_axes_test.onnx new file mode 100644 index 00000000000..9403ee5090d --- /dev/null +++ b/test/onnx/mean_variance_norm_default_axes_test.onnx @@ -0,0 +1,15 @@ + $mean_variance_norm_default_axes_test: +& +dataout"MeanVarianceNormalization$mean_variance_norm_default_axes_testZ +data + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_invalid_axes_test.onnx b/test/onnx/mean_variance_norm_invalid_axes_test.onnx new file mode 100644 index 00000000000..4fa86ca3dac --- /dev/null +++ b/test/onnx/mean_variance_norm_invalid_axes_test.onnx @@ -0,0 +1,12 @@ + $mean_variance_norm_invalid_axes_test: +9 +dataout"MeanVarianceNormalization* +axes@@@@$mean_variance_norm_invalid_axes_testZ +data +  + +b +out +  + +B \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_invalid_type_test.onnx b/test/onnx/mean_variance_norm_invalid_type_test.onnx new file mode 100644 index 00000000000..3d416adcdac --- /dev/null +++ b/test/onnx/mean_variance_norm_invalid_type_test.onnx @@ -0,0 +1,15 @@ + $mean_variance_norm_invalid_type_test: +& +dataout"MeanVarianceNormalization$mean_variance_norm_invalid_type_testZ +data + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx new file mode 100644 index 00000000000..f9155b157cf --- /dev/null +++ b/test/onnx/mean_variance_norm_test.onnx @@ -0,0 +1,16 @@ + mean_variance_norm_test: +5 +dataout"MeanVarianceNormalization* +axes@@mean_variance_norm_testZ +data + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp new file mode 100644 index 00000000000..8c3c17380ba --- /dev/null +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -0,0 +1,76 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + + #include + #include + +TEST_CASE(mean_variance_norm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + const std::vector dims{3, 3, 3, 1}; + const std::vector axes{2, 3}; + migraphx::shape s1{migraphx::shape::float_type, dims}; + + const float eps_default = 1e-7f; + + auto x = mm->add_parameter("data", s1); + + auto expected_val_x = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x); + auto expected_val_sqr_x = add_common_op(*mm, migraphx::make_op("mul"), {expected_val_x, expected_val_x}); + auto x_sqr = add_common_op(*mm, migraphx::make_op("mul"), {x, x}); + auto expected_val_x_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x_sqr); + auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {expected_val_x_sqr, expected_val_sqr_x}); + auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); + auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {x, expected_val_x}); + auto eps_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {eps_default}}); + auto denominator = add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal}); + add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); + + auto prog = optimize_onnx("mean_variance_norm_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(mean_variance_norm_default_axes_test) +{ + migraphx::program p; + const auto prog = optimize_onnx("mean_variance_norm_default_axes_test.onnx"); + const auto* mm = prog.get_main_module(); + const auto it = std::find_if(mm->begin(), mm->end(), [](auto instr){ return instr.name() == "reduce_mean";}); + const auto axes = (*it).get_operator().to_value().at("axes").get_array(); + const auto axes_default = std::vector{0, 2, 3}; + + EXPECT(axes == axes_default); +} + +TEST_CASE(mean_variance_norm_invalid_type_test) +{ + EXPECT(test::throws([&] { optimize_onnx("mean_variance_norm_invalid_type_test.onnx"); })); +} + +TEST_CASE(mean_variance_norm_invalid_axes_test) +{ + EXPECT(test::throws([&] { optimize_onnx("mean_variance_norm_invalid_axes_test.onnx"); })); +} diff --git a/test/onnx/verify/mean_variance_norm_val_test.cpp b/test/onnx/verify/mean_variance_norm_val_test.cpp new file mode 100644 index 00000000000..d9323ca19ae --- /dev/null +++ b/test/onnx/verify/mean_variance_norm_val_test.cpp @@ -0,0 +1,76 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +#include + +TEST_CASE(mean_variance_norm_val_test) +{ + // example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py + migraphx::program p = read_onnx("mean_variance_norm_default_axes_test.onnx"); + + p.compile(migraphx::make_target("ref")); + + const migraphx::shape s{migraphx::shape::float_type, {3, 3, 3, 1}}; + std::vector x = { + 0.843968, 0.566514, 0.0583673, + 0.0291637, 0.129643, 0.50602, + 0.795383, 0.941135, 0.954657, + + 0.177309, 0.461921, 0.264804, + 0.674684, 0.0166526, 0.624731, + 0.924084, 0.972234, 0.119657, + + 0.413562, 0.912937, 0.593301, + 0.819299, 0.78626, 0.117998, + 0.692484, 0.541194, 0.0751322 + }; + + migraphx::parameter_map p_map; + p_map["data"] = migraphx::argument(s, x.data()); + + auto result = p.eval(p_map).back(); + std::vector result_vector(9); + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + const std::vector expected_result = + { + 1.35464, 0.330535, -1.54508, + -1.21068, -0.892595, 0.298881, + 0.380831, 0.818088, 0.858656, + + -1.10606, -0.0555288, -0.783103, + 0.832814, -1.25028, 0.674679, + 0.766937, 0.911387, -1.64636, + + -0.234028, 1.60921, 0.429406, + 1.29061, 1.18602, -0.929458, + 0.0721332, -0.38174, -1.77993 + }; + EXPECT(migraphx::verify::verify_rms_range(result_vector, expected_result)); +}