Implementing parsing of Mean Variance Normalization #202
Conversation
…r test, verify test and associated onnx files
}
assert(X->get_shape().ndim() >= axes_min_size);

auto E_X = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), X);
Using snake case is correct, but you shouldn't use capital letters in local variable names
Fair enough.... I was trying to conform to the mathematical symbol for the expected value of a random variable, which is E(X).
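For reference, the computation these instructions build up can be sketched in NumPy. This is an illustrative reference implementation following the ONNX function definition of MeanVarianceNormalization (the `E[X^2] - E[X]^2` variance decomposition and the helper name are assumptions, not the exact MIGraphX parser code):

```python
import numpy as np

def mvn_reference(x, axes=(0, 2, 3), eps=1e-9):
    """NumPy sketch of the MVN graph: reduce_mean gives E[X],
    and the variance comes from E[X^2] - E[X]^2."""
    e_x  = np.mean(x, axis=axes, keepdims=True)      # reduce_mean -> E[X]
    e_x2 = np.mean(x * x, axis=axes, keepdims=True)  # reduce_mean of X*X
    var  = e_x2 - e_x * e_x                          # Var(X) = E[X^2] - E[X]^2
    return (x - e_x) / np.sqrt(var + eps)

x = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)
y = mvn_reference(x)
# each channel of y should come out zero-mean, unit-variance
print(np.allclose(np.mean(y, axis=(0, 2, 3)), 0.0, atol=1e-5))
```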
axes.assign(info.attributes.at("axes").ints().begin(), info.attributes.at("axes").ints().end());
axes_min_size = axes.size();
}
assert(X->get_shape().ndim() >= axes_min_size);
In the parser you should throw an exception when a condition isn't met. You can do this using the MIGRAPHX_THROW macro, which takes a string as input. There isn't a set format for the string; I usually go with "OperatorName: input_dim/attribute/anything has value x, should be y", something like what you have above for the type validity check.
Fair enough... note: see parse_instancenorm.cpp:88, where there's also an assert(.) call.
const auto& X = args[0];

const auto eps_default = 1e-9f;
This isn't something that's pointed out in the docs, but 1e-9f wouldn't be the correct value to use for fp16; it's too small.
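A quick NumPy illustration of why 1e-9 is too small: the smallest positive (subnormal) float16 is 2⁻²⁴ ≈ 5.96e-8, so 1e-9 rounds to zero and offers no protection against a zero denominator (a sketch under that assumption, not MIGraphX code):

```python
import numpy as np

# 1e-9 is below the smallest positive float16 (2**-24 ~ 5.96e-8),
# so it is flushed to zero on conversion; 1e-7 survives.
eps_too_small = np.float16(1e-9)
eps_ok        = np.float16(1e-7)
print(eps_too_small == 0)  # the epsilon vanishes entirely
print(eps_ok > 0)

# With eps flushed to zero, a constant input (variance 0) makes the
# MVN denominator sqrt(var + eps) exactly zero.
var = np.float16(0.0)
print(np.sqrt(var + eps_too_small) == 0)  # division by zero would follow
```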
struct mean_variance_norm : op_parser<mean_variance_norm>
{
    std::set<shape::type_t> valid_types = {
        shape::bf16_type, shape::double_type, shape::float_type};
Missing fp16 from supported type list.
@onnx_test()
def mean_variance_norm_val_test():
    # example from: https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/meanvariancenormalization.py
    x = np.array(
Usually we want to avoid specifying tensor values in the graph; the instancenorm tests are an outlier in this regard.
By doing this you can reuse the same graph for the parse and verify tests.
You'd provide the values as an input to the compiled migraphx program.
Have a look at how gen_onnx.py:mod_test() does it, and its corresponding cpp file, test/onnx/verify/mod_test.cpp.
node = onnx.helper.make_node('MeanVarianceNormalization',
                             inputs=['x'],
                             outputs=['y'],
                             axes=ax,
You should also add a test where you don't define the axes attribute, to check that the parser correctly defaults it.

Negative parser tests should also be added. If there's a branch in your parser code that leads to a MIGRAPHX_THROW, it should be covered with a test.
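Per the ONNX spec, MeanVarianceNormalization defaults axes to [0, 2, 3] when the attribute is omitted (per-channel statistics over batch and spatial dims of an NCHW tensor). A NumPy sketch of the values such a defaults test could check against (the helper name is illustrative):

```python
import numpy as np

def mvn_default_axes(x, eps=1e-9):
    # ONNX default: normalize over axes [0, 2, 3], i.e. per channel
    mean = np.mean(x, axis=(0, 2, 3), keepdims=True)
    var  = np.var(x, axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.default_rng(0).standard_normal((2, 3, 4, 5)).astype(np.float32)
y = mvn_default_axes(x)
# each channel of y is (approximately) zero-mean, unit-variance
print(np.allclose(np.mean(y, axis=(0, 2, 3)), 0.0, atol=1e-5))
```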
mm->add_return({Y});

migraphx::onnx_options options;
auto prog = read_onnx("mean_variance_norm_test.onnx", options);
Should use optimize_onnx to parse the model; read_onnx was used before, so older tests still contain it.
auto denominator = add_common_op(*mm, migraphx::make_op("add"), {std, eps_literal});
auto Y = add_common_op(*mm, migraphx::make_op("div"), {numerator, denominator});

mm->add_return({Y});
Probably won't need the add_return when using optimize_onnx().
…onnx file; modify tensor shapes
… in the gen_onnx.py file
const auto& x = args[0];

const auto eps_default = 1e-7f;
You'd want eps_default to be 1e-7 if dtype is half_type, and 1e-9 for the other types.
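The suggested defaulting logic can be sketched in Python (the function name and dtype dispatch are illustrative, mirroring the suggestion rather than the actual MIGraphX parser code): pick a larger epsilon for half precision so the default never underflows to zero in its own type.

```python
import numpy as np

def default_epsilon(dtype):
    # half precision needs the larger default; 1e-9 would flush to 0
    return 1e-7 if dtype == np.float16 else 1e-9

for dtype in (np.float16, np.float32, np.float64):
    eps = dtype(default_epsilon(dtype))
    # the chosen default must remain nonzero in its own dtype
    print(dtype.__name__, eps > 0)
```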
migraphx::program p;
const auto prog = optimize_onnx("mean_variance_norm_default_axes_test.onnx");
const auto* mm = prog.get_main_module();
const auto it = std::find_if(mm->begin(), mm->end(), [](auto instr) { return instr.name() == "reduce_mean"; });
I like the approach here. You'd still want to build out an entire "expected" graph for these tests, as that's the approach in migraphx, but kudos for the way you did it. You don't need to make a change to this.