From 30eda28016624b644ac4826954da2be9df6c4dee Mon Sep 17 00:00:00 2001 From: HTEC <> Date: Wed, 26 Mar 2025 13:38:44 +0000 Subject: [PATCH 01/20] implementing the parsing of mean variance normalization; adding parser test, verify test and associated onnx files --- src/onnx/parse_mean_variance_norm.cpp | 94 ++++++++++++++++++ test/onnx/gen_onnx.py | 49 +++++++++ test/onnx/mean_variance_norm_test.onnx | 16 +++ test/onnx/mean_variance_norm_val_test.onnx | Bin 0 -> 288 bytes test/onnx/parse/mean_variance_norm.cpp | 56 +++++++++++ .../verify/mean_variance_norm_val_test.cpp | 57 +++++++++++ 6 files changed, 272 insertions(+) create mode 100644 src/onnx/parse_mean_variance_norm.cpp create mode 100644 test/onnx/mean_variance_norm_test.onnx create mode 100644 test/onnx/mean_variance_norm_val_test.onnx create mode 100644 test/onnx/parse/mean_variance_norm.cpp create mode 100644 test/onnx/verify/mean_variance_norm_val_test.cpp diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp new file mode 100644 index 00000000000..11cd063bba4 --- /dev/null +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -0,0 +1,94 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct mean_variance_norm : op_parser +{ + std::set valid_types = { + shape::bf16_type, shape::double_type, shape::float_type}; + + std::vector operators() const + { + return {{"MeanVarianceNormalization", "mean_variance_norm"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + const auto& X = args[0]; + + const auto Epsilon_default = 1e-9f; + const auto axes_default = std::vector{0, 2, 3}; + + auto Epsilon = Epsilon_default; + if (contains(info.attributes, "epsilon")) + { + Epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); + } + + auto axes = axes_default; + auto axes_min_size = axes.size(); + if (contains(info.attributes, "axes")) + { + axes.assign(info.attributes.at("axes").ints().begin(), info.attributes.at("axes").ints().end()); + axes_min_size = axes.size(); + } + assert(X->get_shape().ndim() >= axes_min_size); + + const auto dtype = args[0]->get_shape().type(); + const auto literal_dtype = dtype; + + if(not contains(valid_types, dtype)) + { + MIGRAPHX_THROW(opd.onnx_name + ": invalid output type: " + std::to_string(dtype) + + ". Valid types are (bfloat16), (double), and (float)."); + } + + auto E_X = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X); + auto E_sqr_X = info.add_common_op("mul", E_X, E_X); + auto X_sqr = info.add_common_op("mul", X, X); + auto E_X_sqr = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X_sqr); + auto std_sqr = info.add_common_op("sub", E_X_sqr, E_sqr_X); + auto std = info.add_common_op("sqrt", std_sqr); + auto numerator = info.add_common_op("sub", X, E_X); + auto Epsilon_literal= info.add_literal(literal{shape{literal_dtype}, {Epsilon}}); + auto denominator= info.add_common_op("add", std, Epsilon_literal); + auto Y = info.add_common_op("div", numerator, denominator); + + return Y; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 298607cac84..ed342b99d6b 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15767,3 +15767,52 @@ def scan_arg_shapes_mismatch_test(): ) return ([node], [init_state, scan_ins1, scan_ins2], [final_state, scan_outs]) + +@onnx_test() +def mean_variance_norm_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + ax = [2, 3] + + node = onnx.helper.make_node('MeanVarianceNormalization', + inputs=['x'], + outputs=['y'], + axes=ax, + ) + return ([node], [x], [y]) + +@onnx_test() +def mean_variance_norm_val_test(): + # example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py + x = np.array( + [ + [ + [[0.8439683], [0.5665144], [0.05836735]], + [[0.02916367], [0.12964272], [0.5060197]], + [[0.79538304], [0.9411346], [0.9546573]], + ], + [ + [[0.17730942], [0.46192095], [0.26480448]], + [[0.6746842], [0.01665257], [0.62473077]], + [[0.9240844], [0.9722341], [0.11965699]], + ], + [ + [[0.41356155], [0.9129373], [0.59330076]], + [[0.81929934], [0.7862604], [0.11799799]], + [[0.69248444], [0.54119414], [0.07513223]], + ], + ], + dtype=np.float32, + ) + + x_tensor = helper.make_tensor(name='x_tensor', data_type=TensorProto.FLOAT, dims=x.shape, vals=x.flatten().astype(np.float)) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, x.shape) + ax = [0, 2, 3] + + node = onnx.helper.make_node('MeanVarianceNormalization', + inputs=['x_tensor'], + outputs=['y'], + axes=ax, + ) + + return ([node], [], [y], [x_tensor]) \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx new file mode 100644 index 00000000000..26d74a611fa --- /dev/null +++ b/test/onnx/mean_variance_norm_test.onnx @@ -0,0 +1,16 @@ + mean_variance_norm_test: +0 +xy"MeanVarianceNormalization* +axes@@mean_variance_norm_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_val_test.onnx b/test/onnx/mean_variance_norm_val_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a61ccb77ab329d5a005b9a70abcda47eb9bf2220 GIT binary patch literal 288 zcmdaj)qX(5i8+~7i6xo&d0PBjEQuAV#SRP(Ob*No7}@a}t<}WH!3>6s0*p#IetZ%3 zV(fzUvxM?(m%VsrBXE$#uBnUB-g2I|z0tb@`<#>#`&koA?VdNhu`8L9}ZB+bQGDI~@v!oescz{SMDjD#7Tm;{6Y%)Vc^ literal 0 HcmV?d00001 diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp new file mode 100644 index 00000000000..8f18e87489d --- /dev/null +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + + #include + #include + +TEST_CASE(mean_variance_norm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector dims{3, 4, 5, 6}; + migraphx::shape s1{migraphx::shape::float_type, dims}; + + const float epsilon_default = 1e-9f; + + auto X = mm->add_parameter("x", s1); + + auto E_X = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), X); + auto E_sqr_X = add_common_op(*mm, migraphx::make_op("mul"), {E_X, E_X}); + auto X_sqr = add_common_op(*mm, migraphx::make_op("mul"), {X, X}); + auto E_X_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), X_sqr); + auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {E_X_sqr, E_sqr_X}); + auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); + auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {X, E_X}); + auto Epsilon_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {epsilon_default}}); + auto denominator= add_common_op(*mm, migraphx::make_op("add"), {std, Epsilon_literal}); + auto Y = add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); + + mm->add_return({Y}); + + migraphx::onnx_options options; + auto prog = read_onnx("mean_variance_norm_test.onnx", options); + EXPECT(p == prog); +} diff --git a/test/onnx/verify/mean_variance_norm_val_test.cpp b/test/onnx/verify/mean_variance_norm_val_test.cpp new file mode 100644 index 00000000000..509107cf5f5 --- /dev/null +++ b/test/onnx/verify/mean_variance_norm_val_test.cpp @@ -0,0 +1,57 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +#include + +TEST_CASE(mean_variance_norm_val_test) +{ + migraphx::program p = read_onnx("mean_variance_norm_val_test.onnx"); + + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector result_vector(9); + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py + const std::vector expected_result = + { + 1.35464, 0.330535, -1.54508, + -1.21068, -0.892595, 0.298881, + 0.380831, 0.818088, 0.858656, + + -1.10606, -0.0555288, -0.783103, + 0.832814, -1.25028, 0.674679, + 0.766937, 0.911387, -1.64636, + + -0.234028, 1.60921, 0.429406, + 1.29061, 1.18602, -0.929458, + 0.0721332, -0.38174, -1.77993 + }; + EXPECT(migraphx::verify::verify_rms_range(result_vector, expected_result)); +} From f395a8fd860215727cc0af7c89b2cc7933aa8654 Mon Sep 17 00:00:00 2001 From: HTEC <> Date: Wed, 26 Mar 2025 15:07:00 +0000 Subject: [PATCH 02/20] have the type check at the top of the parsing function --- src/onnx/parse_mean_variance_norm.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index 11cd063bba4..2247cf84726 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -45,6 +45,15 @@ struct mean_variance_norm : op_parser const onnx_parser::node_info& info, std::vector args) const { + const auto dtype = args[0]->get_shape().type(); + const auto literal_dtype = dtype; + + if(not contains(valid_types, dtype)) + { + MIGRAPHX_THROW(opd.onnx_name + ": invalid output type: " + std::to_string(dtype) + + ". Valid types are (bfloat16), (double), and (float)."); + } + const auto& X = args[0]; const auto Epsilon_default = 1e-9f; @@ -65,15 +74,6 @@ struct mean_variance_norm : op_parser } assert(X->get_shape().ndim() >= axes_min_size); - const auto dtype = args[0]->get_shape().type(); - const auto literal_dtype = dtype; - - if(not contains(valid_types, dtype)) - { - MIGRAPHX_THROW(opd.onnx_name + ": invalid output type: " + std::to_string(dtype) + - ". Valid types are (bfloat16), (double), and (float)."); - } - auto E_X = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X); auto E_sqr_X = info.add_common_op("mul", E_X, E_X); auto X_sqr = info.add_common_op("mul", X, X); From 650adc4e70a35034766a1a98229857a51279bcec Mon Sep 17 00:00:00 2001 From: HTEC <> Date: Wed, 26 Mar 2025 15:14:48 +0000 Subject: [PATCH 03/20] tiny renaming --- src/onnx/parse_mean_variance_norm.cpp | 10 +++++----- test/onnx/parse/mean_variance_norm.cpp | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index 2247cf84726..0a3b54c7f61 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -56,13 +56,13 @@ struct mean_variance_norm : op_parser const auto& X = args[0]; - const auto Epsilon_default = 1e-9f; + const auto eps_default = 1e-9f; const auto axes_default = std::vector{0, 2, 3}; - auto Epsilon = Epsilon_default; + auto eps = eps_default; if (contains(info.attributes, "epsilon")) { - Epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); + eps = parser.parse_value(info.attributes.at("epsilon")).at(); } auto axes = axes_default; @@ -81,8 +81,8 @@ struct mean_variance_norm : op_parser auto std_sqr = info.add_common_op("sub", E_X_sqr, E_sqr_X); auto std = info.add_common_op("sqrt", std_sqr); auto numerator = info.add_common_op("sub", X, E_X); - auto Epsilon_literal= info.add_literal(literal{shape{literal_dtype}, {Epsilon}}); - auto denominator= info.add_common_op("add", std, Epsilon_literal); + auto eps_literal= info.add_literal(literal{shape{literal_dtype}, {eps}}); + auto denominator= info.add_common_op("add", std, eps_literal); auto Y = info.add_common_op("div", numerator, denominator); return Y; diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 8f18e87489d..8a4b0defe79 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -33,10 +33,10 @@ TEST_CASE(mean_variance_norm_test) std::vector dims{3, 4, 5, 6}; migraphx::shape s1{migraphx::shape::float_type, dims}; - const float epsilon_default = 1e-9f; + const float eps_default = 1e-9f; auto X = mm->add_parameter("x", s1); - + auto E_X = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), X); auto E_sqr_X = add_common_op(*mm, migraphx::make_op("mul"), {E_X, E_X}); auto X_sqr = add_common_op(*mm, migraphx::make_op("mul"), {X, X}); @@ -44,8 +44,8 @@ TEST_CASE(mean_variance_norm_test) auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {E_X_sqr, E_sqr_X}); auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {X, E_X}); - auto Epsilon_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {epsilon_default}}); - auto denominator= add_common_op(*mm, migraphx::make_op("add"), {std, Epsilon_literal}); + auto eps_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {eps_default}}); + auto denominator= add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal}); auto Y = add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); mm->add_return({Y}); From b5e470a255f8e19029c7d11a634c6a3b4f277177 Mon Sep 17 00:00:00 2001 From: Gyula Chinoradszki Date: Tue, 1 Apr 2025 11:09:32 +0200 Subject: [PATCH 04/20] 'resolving review comments part 1' --- src/onnx/parse_mean_variance_norm.cpp | 37 +++++++++++++++------------ 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index 0a3b54c7f61..0f94a479424 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -33,7 +33,7 @@ namespace onnx { struct mean_variance_norm : op_parser { std::set valid_types = { - shape::bf16_type, shape::double_type, shape::float_type}; + shape::bf16_type, shape::double_type, shape::float_type, shape::half_type}; std::vector operators() const { @@ -51,12 +51,12 @@ struct mean_variance_norm : op_parser if(not contains(valid_types, dtype)) { MIGRAPHX_THROW(opd.onnx_name + ": invalid output type: " + std::to_string(dtype) + - ". Valid types are (bfloat16), (double), and (float)."); + ". Valid types are (bfloat16), (double), (float) and (half)"); } - const auto& X = args[0]; + const auto& x = args[0]; - const auto eps_default = 1e-9f; + const auto eps_default = 1e-7f; const auto axes_default = std::vector{0, 2, 3}; auto eps = eps_default; @@ -72,20 +72,25 @@ struct mean_variance_norm : op_parser axes.assign(info.attributes.at("axes").ints().begin(), info.attributes.at("axes").ints().end()); axes_min_size = axes.size(); } - assert(X->get_shape().ndim() >= axes_min_size); + + if (x->get_shape().ndim() < axes_min_size) + { + MIGRAPHX_THROW(opd.onnx_name + ": input dimension has value: " + std::to_string(x->get_shape().ndim()) + + ". It sould be greater or equal to: " + std::to_string(axes_min_size)) + } - auto E_X = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X); - auto E_sqr_X = info.add_common_op("mul", E_X, E_X); - auto X_sqr = info.add_common_op("mul", X, X); - auto E_X_sqr = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X_sqr); - auto std_sqr = info.add_common_op("sub", E_X_sqr, E_sqr_X); - auto std = info.add_common_op("sqrt", std_sqr); - auto numerator = info.add_common_op("sub", X, E_X); - auto eps_literal= info.add_literal(literal{shape{literal_dtype}, {eps}}); - auto denominator= info.add_common_op("add", std, eps_literal); - auto Y = info.add_common_op("div", numerator, denominator); + auto expected_val_x = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); + auto expected_val_sqr_x = info.add_common_op("mul", expected_val_x, expected_val_x); + auto x_sqr = info.add_common_op("mul", x, x); + auto expected_val_x_sqr = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqr); + auto std_sqr = info.add_common_op("sub", expected_val_x_sqr, expected_val_sqr_x); + auto std = info.add_common_op("sqrt", std_sqr); + auto numerator = info.add_common_op("sub", x, expected_val_x); + auto eps_literal = info.add_literal(literal{shape{literal_dtype}, {eps}}); + auto denominator = info.add_common_op("add", std, eps_literal); + auto y = info.add_common_op("div", numerator, denominator); - return Y; + return y; } }; From 0a371b2ec43a34d40cded29bac516fe94583736f Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 11:12:45 +0000 Subject: [PATCH 05/20] add missing semicolon --- src/onnx/parse_mean_variance_norm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index 0f94a479424..49205f3e671 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -76,7 +76,7 @@ struct mean_variance_norm : op_parser if (x->get_shape().ndim() < axes_min_size) { MIGRAPHX_THROW(opd.onnx_name + ": input dimension has value: " + std::to_string(x->get_shape().ndim()) + - ". It sould be greater or equal to: " + std::to_string(axes_min_size)) + ". It sould be greater or equal to: " + std::to_string(axes_min_size)); } auto expected_val_x = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); From ce9da7eec8c81fc23982378c37bc12f9a28319cd Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 11:51:49 +0000 Subject: [PATCH 06/20] change epsilon default in parser test too --- test/onnx/parse/mean_variance_norm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 8a4b0defe79..a03679622af 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -33,7 +33,7 @@ TEST_CASE(mean_variance_norm_test) std::vector dims{3, 4, 5, 6}; migraphx::shape s1{migraphx::shape::float_type, dims}; - const float eps_default = 1e-9f; + const float eps_default = 1e-7f; auto X = mm->add_parameter("x", s1); From f86da7f58ae7da6f44d03a02edadad5a0a4b63ab Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:16:42 +0000 Subject: [PATCH 07/20] ident fix --- src/onnx/parse_mean_variance_norm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index 49205f3e671..a4ec778058f 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -54,9 +54,9 @@ struct mean_variance_norm : op_parser ". Valid types are (bfloat16), (double), (float) and (half)"); } - const auto& x = args[0]; + const auto& x = args[0]; - const auto eps_default = 1e-7f; + const auto eps_default = 1e-7f; const auto axes_default = std::vector{0, 2, 3}; auto eps = eps_default; From 01372ec43873b5d33dcfe352932d0fffff7da757 Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:18:10 +0000 Subject: [PATCH 08/20] renaming; using optimize_onnx instead of read_onny --- test/onnx/parse/mean_variance_norm.cpp | 28 ++++++++++++-------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index a03679622af..f2c28232728 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -35,22 +35,20 @@ TEST_CASE(mean_variance_norm_test) const float eps_default = 1e-7f; - auto X = mm->add_parameter("x", s1); - - auto E_X = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), X); - auto E_sqr_X = add_common_op(*mm, migraphx::make_op("mul"), {E_X, E_X}); - auto X_sqr = add_common_op(*mm, migraphx::make_op("mul"), {X, X}); - auto E_X_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), X_sqr); - auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {E_X_sqr, E_sqr_X}); - auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); - auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {X, E_X}); - auto eps_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {eps_default}}); - auto denominator= add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal}); - auto Y = add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); - - mm->add_return({Y}); + auto x = mm->add_parameter("x", s1); + + auto expected_val_x = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); + auto expected_val_sqr_x = add_common_op(*mm, migraphx::make_op("mul"), {expected_val_x, expected_val_x}); + auto x_sqr = add_common_op(*mm, migraphx::make_op("mul"), {x, x}); + auto expected_val_x_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x_sqr); + auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {expected_val_x_sqr, expected_val_sqr_x}); + auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); + auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {x, expected_val_x}); + auto eps_literal = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {eps_default}}); + auto denominator = add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal}); + add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); migraphx::onnx_options options; - auto prog = read_onnx("mean_variance_norm_test.onnx", options); + auto prog = optimize_onnx("mean_variance_norm_test.onnx"); EXPECT(p == prog); } From 817769ffc55ecb389fb6f05213650905ca6e8d36 Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:36:11 +0000 Subject: [PATCH 09/20] remove onnx generator that included tensor values; remove associated onnx file; modify tensor shapes --- test/onnx/gen_onnx.py | 42 ++------------------- test/onnx/mean_variance_norm_test.onnx | 16 -------- test/onnx/mean_variance_norm_val_test.onnx | Bin 288 -> 0 bytes 3 files changed, 3 insertions(+), 55 deletions(-) delete mode 100644 test/onnx/mean_variance_norm_test.onnx delete mode 100644 test/onnx/mean_variance_norm_val_test.onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index ed342b99d6b..df4b4bbfc99 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15770,9 +15770,9 @@ def scan_arg_shapes_mismatch_test(): @onnx_test() def mean_variance_norm_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) - ax = [2, 3] + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 3, 1]) + ax = [0, 2, 3] node = onnx.helper.make_node('MeanVarianceNormalization', inputs=['x'], @@ -15780,39 +15780,3 @@ def mean_variance_norm_test(): axes=ax, ) return ([node], [x], [y]) - -@onnx_test() -def mean_variance_norm_val_test(): - # example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py - x = np.array( - [ - [ - [[0.8439683], [0.5665144], [0.05836735]], - [[0.02916367], [0.12964272], [0.5060197]], - [[0.79538304], [0.9411346], [0.9546573]], - ], - [ - [[0.17730942], [0.46192095], [0.26480448]], - [[0.6746842], [0.01665257], [0.62473077]], - [[0.9240844], [0.9722341], [0.11965699]], - ], - [ - [[0.41356155], [0.9129373], [0.59330076]], - [[0.81929934], [0.7862604], [0.11799799]], - [[0.69248444], [0.54119414], [0.07513223]], - ], - ], - dtype=np.float32, - ) - - x_tensor = helper.make_tensor(name='x_tensor', data_type=TensorProto.FLOAT, dims=x.shape, vals=x.flatten().astype(np.float)) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, x.shape) - ax = [0, 2, 3] - - node = onnx.helper.make_node('MeanVarianceNormalization', - inputs=['x_tensor'], - outputs=['y'], - axes=ax, - ) - - return ([node], [], [y], [x_tensor]) \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx deleted file mode 100644 index 26d74a611fa..00000000000 --- a/test/onnx/mean_variance_norm_test.onnx +++ /dev/null @@ -1,16 +0,0 @@ - mean_variance_norm_test: -0 -xy"MeanVarianceNormalization* -axes@@mean_variance_norm_testZ -x - - - - -b -y - - - - -B \ No newline at end of file diff --git a/test/onnx/mean_variance_norm_val_test.onnx b/test/onnx/mean_variance_norm_val_test.onnx deleted file mode 100644 index a61ccb77ab329d5a005b9a70abcda47eb9bf2220..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 288 zcmdaj)qX(5i8+~7i6xo&d0PBjEQuAV#SRP(Ob*No7}@a}t<}WH!3>6s0*p#IetZ%3 zV(fzUvxM?(m%VsrBXE$#uBnUB-g2I|z0tb@`<#>#`&koA?VdNhu`8L9}ZB+bQGDI~@v!oescz{SMDjD#7Tm;{6Y%)Vc^ From 67cdc0c21e8328b1ad9ca2558a99fed09583d7fc Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:45:17 +0000 Subject: [PATCH 10/20] add new onnx file; modify test a bit according to how it is generated in the gen_onnx.py file --- test/onnx/mean_variance_norm_test.onnx | Bin 0 -> 169 bytes test/onnx/parse/mean_variance_norm.cpp | 7 ++++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 test/onnx/mean_variance_norm_test.onnx diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..95d12317e448afa4fc15885cd8022af335bc95c8 GIT binary patch literal 169 zcmdOV literal 0 HcmV?d00001 diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index f2c28232728..3cb0bdf35ef 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -30,17 +30,18 @@ TEST_CASE(mean_variance_norm_test) migraphx::program p; auto* mm = p.get_main_module(); - std::vector dims{3, 4, 5, 6}; + const std::vector dims{3, 3, 3, 1}; + const std::vector axes{0, 2, 3}; migraphx::shape s1{migraphx::shape::float_type, dims}; const float eps_default = 1e-7f; auto x = mm->add_parameter("x", s1); - auto expected_val_x = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); + auto expected_val_x = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x); auto expected_val_sqr_x = add_common_op(*mm, migraphx::make_op("mul"), {expected_val_x, expected_val_x}); auto x_sqr = add_common_op(*mm, migraphx::make_op("mul"), {x, x}); - auto expected_val_x_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x_sqr); + auto expected_val_x_sqr = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x_sqr); auto std_sqr = add_common_op(*mm, migraphx::make_op("sub"), {expected_val_x_sqr, expected_val_sqr_x}); auto std = add_common_op(*mm, migraphx::make_op("sqrt"), {std_sqr}); auto numerator = add_common_op(*mm, migraphx::make_op("sub"), {x, expected_val_x}); From e10e776c6d1441e581a4b8f977b6e91b0a001779 Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:49:51 +0000 Subject: [PATCH 11/20] delete onnx with wrong axes value --- test/onnx/mean_variance_norm_test.onnx | Bin 169 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/onnx/mean_variance_norm_test.onnx diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx deleted file mode 100644 index 95d12317e448afa4fc15885cd8022af335bc95c8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 169 zcmdOV From f660e4a787d9366826ef614069e5f09ca57a6902 Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 15:52:47 +0000 Subject: [PATCH 12/20] yet another onnx file and the modified test --- test/onnx/gen_onnx.py | 2 +- test/onnx/mean_variance_norm_test.onnx | 16 ++++++++++++++++ test/onnx/parse/mean_variance_norm.cpp | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 test/onnx/mean_variance_norm_test.onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index df4b4bbfc99..fdd18a93960 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15772,7 +15772,7 @@ def scan_arg_shapes_mismatch_test(): def mean_variance_norm_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3, 1]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 3, 1]) - ax = [0, 2, 3] + ax = [2, 3] node = onnx.helper.make_node('MeanVarianceNormalization', inputs=['x'], diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx new file mode 100644 index 00000000000..1e3d94bafcd --- /dev/null +++ b/test/onnx/mean_variance_norm_test.onnx @@ -0,0 +1,16 @@ + mean_variance_norm_test: +0 +xy"MeanVarianceNormalization* +axes@@mean_variance_norm_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 3cb0bdf35ef..68c892b46f9 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -31,7 +31,7 @@ TEST_CASE(mean_variance_norm_test) auto* mm = p.get_main_module(); const std::vector dims{3, 3, 3, 1}; - const std::vector axes{0, 2, 3}; + const std::vector axes{2, 3}; migraphx::shape s1{migraphx::shape::float_type, dims}; const float eps_default = 1e-7f; From 38e132a1d36daaeb79aaa6b0e2602127fe8e9b8d Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 16:03:57 +0000 Subject: [PATCH 13/20] reuse of existing helper function --- test/onnx/mean_variance_norm_test.onnx | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/onnx/mean_variance_norm_test.onnx b/test/onnx/mean_variance_norm_test.onnx index 1e3d94bafcd..f9155b157cf 100644 --- a/test/onnx/mean_variance_norm_test.onnx +++ b/test/onnx/mean_variance_norm_test.onnx @@ -1,14 +1,14 @@ - mean_variance_norm_test: -0 -xy"MeanVarianceNormalization* -axes@@mean_variance_norm_testZ -x + mean_variance_norm_test: +5 +dataout"MeanVarianceNormalization* +axes@@mean_variance_norm_testZ +data     -b -y +b +out    From 8148b1d9919acbf9e2ad238893285137e28a0b49 Mon Sep 17 00:00:00 2001 From: gchinora Date: Tue, 1 Apr 2025 16:05:21 +0000 Subject: [PATCH 14/20] 'reuse existing helper function' --- test/onnx/gen_onnx.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index fdd18a93960..5ca6066cfde 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15770,13 +15770,4 @@ def scan_arg_shapes_mismatch_test(): @onnx_test() def mean_variance_norm_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 3, 3, 1]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 3, 3, 1]) - ax = [2, 3] - - node = onnx.helper.make_node('MeanVarianceNormalization', - inputs=['x'], - outputs=['y'], - axes=ax, - ) - return ([node], [x], [y]) + return mvn_n_rank_test_base([2, 3], [3, 3, 3, 1]) From f8f488388d3f71cd96ac5cffa82cd0d6fb283de0 Mon Sep 17 00:00:00 2001 From: gchinora Date: Wed, 2 Apr 2025 07:35:04 +0000 Subject: [PATCH 15/20] fix; change parameter name in test --- test/onnx/parse/mean_variance_norm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 68c892b46f9..694f5f9b44e 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -36,7 +36,7 @@ TEST_CASE(mean_variance_norm_test) const float eps_default = 1e-7f; - auto x = mm->add_parameter("x", s1); + auto x = mm->add_parameter("data", s1); auto expected_val_x = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x); auto expected_val_sqr_x = add_common_op(*mm, migraphx::make_op("mul"), {expected_val_x, expected_val_x}); From 3dfd9ae4a3942d5bf44e6cab7886798462b6628c Mon Sep 17 00:00:00 2001 From: gchinora Date: Wed, 2 Apr 2025 07:37:19 +0000 Subject: [PATCH 16/20] 'Add test; testing if axes are defaulted if not specified' --- test/onnx/gen_onnx.py | 4 ++++ .../mean_variance_norm_default_axes_test.onnx | 15 +++++++++++++++ test/onnx/parse/mean_variance_norm.cpp | 13 ++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 test/onnx/mean_variance_norm_default_axes_test.onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 5ca6066cfde..7089fb3a3f0 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15771,3 +15771,7 @@ def scan_arg_shapes_mismatch_test(): @onnx_test() def mean_variance_norm_test(): return mvn_n_rank_test_base([2, 3], [3, 3, 3, 1]) + +@onnx_test() +def mean_variance_norm_default_axes_test(): + return mvn_default_axes_test_base([3, 3, 3, 1]) diff --git a/test/onnx/mean_variance_norm_default_axes_test.onnx b/test/onnx/mean_variance_norm_default_axes_test.onnx new file mode 100644 index 00000000000..9403ee5090d --- /dev/null +++ b/test/onnx/mean_variance_norm_default_axes_test.onnx @@ -0,0 +1,15 @@ + $mean_variance_norm_default_axes_test: +& +dataout"MeanVarianceNormalization$mean_variance_norm_default_axes_testZ +data + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 694f5f9b44e..8a735521342 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -49,7 +49,18 @@ TEST_CASE(mean_variance_norm_test) auto denominator = add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal}); add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator}); - migraphx::onnx_options options; auto prog = optimize_onnx("mean_variance_norm_test.onnx"); EXPECT(p == prog); } + +TEST_CASE(mean_variance_norm_default_axes_test) +{ + migraphx::program p; + const auto prog = optimize_onnx("mean_variance_norm_default_axes_test.onnx"); + const auto* mm = prog.get_main_module(); + const auto it = std::find_if(mm->begin(), mm->end(), [](auto instr){ return instr.name() == "reduce_mean";}); + const auto axes = (*it).get_operator().to_value().at("axes").get_array(); + const auto axes_default = std::vector{0, 2, 3}; + + EXPECT(axes == axes_default); +} From ab9928fc05e472724b8853392fd57c21469f8873 Mon Sep 17 00:00:00 2001 From: gchinora Date: Wed, 2 Apr 2025 08:43:01 +0000 Subject: [PATCH 17/20] 'removing data from python and moving it to the verify test' --- .../verify/mean_variance_norm_val_test.cpp | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/test/onnx/verify/mean_variance_norm_val_test.cpp b/test/onnx/verify/mean_variance_norm_val_test.cpp index 509107cf5f5..d9323ca19ae 100644 --- a/test/onnx/verify/mean_variance_norm_val_test.cpp +++ b/test/onnx/verify/mean_variance_norm_val_test.cpp @@ -31,14 +31,33 @@ TEST_CASE(mean_variance_norm_val_test) { - migraphx::program p = read_onnx("mean_variance_norm_val_test.onnx"); + // example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py + migraphx::program p = read_onnx("mean_variance_norm_default_axes_test.onnx"); p.compile(migraphx::make_target("ref")); - auto result = p.eval({}).back(); + + const migraphx::shape s{migraphx::shape::float_type, {3, 3, 3, 1}}; + std::vector x = { + 0.843968, 0.566514, 0.0583673, + 0.0291637, 0.129643, 0.50602, + 0.795383, 0.941135, 0.954657, + + 0.177309, 0.461921, 0.264804, + 0.674684, 0.0166526, 0.624731, + 0.924084, 0.972234, 0.119657, + + 0.413562, 0.912937, 0.593301, + 0.819299, 0.78626, 0.117998, + 0.692484, 0.541194, 0.0751322 + }; + + migraphx::parameter_map p_map; + p_map["data"] = migraphx::argument(s, x.data()); + + auto result = p.eval(p_map).back(); std::vector result_vector(9); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - // example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py const std::vector expected_result = { 1.35464, 0.330535, -1.54508, @@ -53,5 +72,5 @@ TEST_CASE(mean_variance_norm_val_test) 1.29061, 1.18602, -0.929458, 0.0721332, -0.38174, -1.77993 }; - EXPECT(migraphx::verify::verify_rms_range(result_vector, expected_result)); + EXPECT(migraphx::verify::verify_rms_range(result_vector, expected_result)); } From 85a4ec961488da2793de6f95c04701e18580e963 Mon Sep 17 00:00:00 2001 From: gchinora Date: Wed, 2 Apr 2025 09:35:37 +0000 Subject: [PATCH 18/20] 'adding test for invalid type case' --- test/onnx/gen_onnx.py | 4 ++++ .../mean_variance_norm_invalid_type_test.onnx | 15 +++++++++++++++ test/onnx/parse/mean_variance_norm.cpp | 5 +++++ 3 files changed, 24 insertions(+) create mode 100644 test/onnx/mean_variance_norm_invalid_type_test.onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 7089fb3a3f0..06809f20cac 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15775,3 +15775,7 @@ def mean_variance_norm_test(): @onnx_test() def mean_variance_norm_default_axes_test(): return mvn_default_axes_test_base([3, 3, 3, 1]) + +@onnx_test() +def mean_variance_norm_invalid_type_test(): + return mvn_default_axes_test_base([3, 3, 3, 1], type=TensorProto.INT8) diff --git a/test/onnx/mean_variance_norm_invalid_type_test.onnx b/test/onnx/mean_variance_norm_invalid_type_test.onnx new file mode 100644 index 00000000000..3d416adcdac --- /dev/null +++ b/test/onnx/mean_variance_norm_invalid_type_test.onnx @@ -0,0 +1,15 @@ + $mean_variance_norm_invalid_type_test: +& +dataout"MeanVarianceNormalization$mean_variance_norm_invalid_type_testZ +data + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 8a735521342..8a60cd950c3 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -64,3 +64,8 @@ TEST_CASE(mean_variance_norm_default_axes_test) EXPECT(axes == axes_default); } + +TEST_CASE(mean_variance_norm_invalid_type_test) +{ + EXPECT(test::throws([&] { optimize_onnx("mean_variance_norm_invalid_type_test.onnx"); })); +} From 9725f4be9038769c4ceebcf520953c975fd05e9d Mon Sep 17 00:00:00 2001 From: gchinora Date: Wed, 2 Apr 2025 09:45:36 +0000 Subject: [PATCH 19/20] 'adding test for invalid input axes shape' --- test/onnx/gen_onnx.py | 4 ++++ test/onnx/mean_variance_norm_invalid_axes_test.onnx | 12 ++++++++++++ test/onnx/parse/mean_variance_norm.cpp | 5 +++++ 3 files changed, 21 insertions(+) create mode 100644 test/onnx/mean_variance_norm_invalid_axes_test.onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 06809f20cac..d3ab0a6bfe3 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -15779,3 +15779,7 @@ def mean_variance_norm_default_axes_test(): @onnx_test() def mean_variance_norm_invalid_type_test(): return mvn_default_axes_test_base([3, 3, 3, 1], type=TensorProto.INT8) + +@onnx_test() +def mean_variance_norm_invalid_axes_test(): + return mvn_n_rank_test_base(axes=[2, 3, 1, 4], dims=[3, 1]) diff --git a/test/onnx/mean_variance_norm_invalid_axes_test.onnx b/test/onnx/mean_variance_norm_invalid_axes_test.onnx new file mode 100644 index 00000000000..4fa86ca3dac --- /dev/null +++ b/test/onnx/mean_variance_norm_invalid_axes_test.onnx @@ -0,0 +1,12 @@ + $mean_variance_norm_invalid_axes_test: +9 +dataout"MeanVarianceNormalization* +axes@@@@$mean_variance_norm_invalid_axes_testZ +data +  + +b +out +  + +B \ No newline at end of file diff --git a/test/onnx/parse/mean_variance_norm.cpp b/test/onnx/parse/mean_variance_norm.cpp index 8a60cd950c3..8c3c17380ba 100644 --- a/test/onnx/parse/mean_variance_norm.cpp +++ b/test/onnx/parse/mean_variance_norm.cpp @@ -69,3 +69,8 @@ TEST_CASE(mean_variance_norm_invalid_type_test) { EXPECT(test::throws([&] { optimize_onnx("mean_variance_norm_invalid_type_test.onnx"); })); } + +TEST_CASE(mean_variance_norm_invalid_axes_test) +{ + EXPECT(test::throws([&] { optimize_onnx("mean_variance_norm_invalid_axes_test.onnx"); })); +} From 37654c34c7fae28bb479c4649aa45d6eeb7ebf0e Mon Sep 17 00:00:00 2001 From: gchinora Date: Mon, 7 Apr 2025 10:51:47 +0000 Subject: [PATCH 20/20] 'set epsilon according to input type' --- src/onnx/parse_mean_variance_norm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx/parse_mean_variance_norm.cpp b/src/onnx/parse_mean_variance_norm.cpp index a4ec778058f..f1a9716de4f 100644 --- a/src/onnx/parse_mean_variance_norm.cpp +++ b/src/onnx/parse_mean_variance_norm.cpp @@ -56,7 +56,7 @@ struct mean_variance_norm : op_parser const auto& x = args[0]; - const auto eps_default = 1e-7f; + const auto eps_default = (dtype == shape::half_type) ? 1e-7f : 1e-9f; const auto axes_default = std::vector{0, 2, 3}; auto eps = eps_default;