-
Notifications
You must be signed in to change notification settings - Fork 6
dashboard LSTMs outputs #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b8e5591
668749c
6eb9e00
19848ac
653e537
0814723
1e5a4ed
af0da1d
9b50c6a
ad4dfbe
a833485
a6ffd36
3840511
4c7a0b6
8defc0f
c65d393
a480d5e
14d83ed
6ca9110
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ using Makie.Colors | |
| using DataFrames | ||
| import Makie | ||
| import EasyHybrid | ||
| import EasyHybrid: get_loss_value | ||
| import EasyHybrid: get_loss_value, _get_target_y, _get_target_ŷ | ||
| using Statistics | ||
|
|
||
| include("HybridTheme.jl") | ||
|
|
@@ -346,7 +346,7 @@ function EasyHybrid.train_board( | |
| Label(gd_tm[i][1, 1:2, Top()], "$(t)") | ||
|
|
||
| Makie.scatter!(ax_tr, p_tr_sub, o_tr_sub; color = :grey25, alpha = 0.6, markersize = 6) | ||
| Makie.lines!(ax_tr, sort(o_tr), sort(o_tr); color = :black, linestyle = :dash) | ||
| # Makie.lines!(ax_tr, sort(o_tr), sort(o_tr); color = :black, linestyle = :dash) | ||
| # Validation scatter plot | ||
| ax_val = Makie.Axis( | ||
| gd_tm[i][1, 2]; aspect = 1, xlabel = "Predicted", ylabel = "", | ||
|
|
@@ -371,7 +371,7 @@ function EasyHybrid.train_board( | |
| o_val_sub = @lift(o_val[$val_idx]) | ||
|
|
||
| Makie.scatter!(ax_val, p_val_sub, o_val_sub; color = :tomato, alpha = 0.6, markersize = 6) | ||
| Makie.lines!(ax_val, sort(o_val), sort(o_val); color = :black, linestyle = :dash) | ||
| # Makie.lines!(ax_val, sort(o_val), sort(o_val); color = :black, linestyle = :dash) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The 1:1 reference line for the validation scatter plot has been commented out. For consistency with the training plot and to better assess performance, I recommend adding it back using the same Makie.lines!(ax_val, [mn, mx], [mn, mx]; color = :black, linestyle = :dash) |
||
| end | ||
| Label(gd_tm[1][1:end, 0], "Observed", tellheight = false, rotation = pi / 2) | ||
| # Label(gd_tm[end+1,1:end], "Predicted") | ||
|
|
@@ -433,7 +433,9 @@ function EasyHybrid.update_plotting_observables( | |
| current_ŷ_val, | ||
| target_names, | ||
| epoch; | ||
| monitor_names | ||
| monitor_names, | ||
| y_train, | ||
| y_val | ||
| ) | ||
|
|
||
| l_value = get_loss_value(l_train, training_loss, Symbol("$agg")) | ||
|
|
@@ -447,8 +449,15 @@ function EasyHybrid.update_plotting_observables( | |
|
|
||
| for t in target_names | ||
| # replace the array stored in the Observable: | ||
| train_preds[t][] = vec(getfield(current_ŷ_train, t)) | ||
| val_preds[t][] = vec(getfield(current_ŷ_val, t)) | ||
| current_y_train = _get_target_y(y_train, t) | ||
| current_y_train_hat = _get_target_ŷ(current_ŷ_train, current_y_train, t) | ||
|
|
||
| current_y_val = _get_target_y(y_val, t) | ||
| current_y_val_hat = _get_target_ŷ(current_ŷ_val, current_y_val, t) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is an inconsistent use of the function name here. The function is imported as current_y_val_hat = _get_target_ŷ(current_ŷ_val, current_y_val, t) |
||
|
|
||
| train_preds[t][] = current_y_train_hat | ||
| val_preds[t][] = current_y_val_hat | ||
|
|
||
| # and notify Makie that it changed: | ||
| notify(train_preds[t]) | ||
| notify(val_preds[t]) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,10 +59,13 @@ function initialize_plotting_observables(init_ŷ_train, init_ŷ_val, y_train, | |
| val_h_obs = to_obs([p_val]) | ||
|
|
||
| # build NamedTuples of Observables for preds and obs | ||
| train_preds = to_obs_tuple(init_ŷ_train, target_names) | ||
| val_preds = to_obs_tuple(init_ŷ_val, target_names) | ||
| train_obs = map(Array, y_train) | ||
| val_obs = map(Array, y_val) | ||
| train_preds, train_obs = to_obs_tuple(init_ŷ_train, y_train, target_names) | ||
| val_preds, val_obs = to_obs_tuple(init_ŷ_val, y_val, target_names) | ||
|
|
||
| # train_preds = to_obs_tuple(init_ŷ_train, target_names) | ||
| # val_preds = to_obs_tuple(init_ŷ_val, target_names) | ||
| # train_obs = to_tuple(y_train, target_names) | ||
| # val_obs = to_tuple(y_val, target_names) | ||
|
Comment on lines
+65
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # --- monitored parameters/state as Observables --- | ||
| train_monitor = !isempty(monitor_names) ? monitor_to_obs(init_ŷ_train, monitor_names) : nothing | ||
|
|
@@ -113,6 +116,29 @@ function to_obs_tuple(y, target_names) | |
| return (; (t => to_obs(vec(getfield(y, t))) for t in target_names)...) | ||
| end | ||
|
|
||
| function to_obs_tuple(ŷ, y, target_names) | ||
| # first get observations, they could have a time dimension | ||
| tmp_obs = [] | ||
| tmp_pred = [] | ||
| for t in target_names | ||
| y_ = _get_target_y(y, t) | ||
| ŷ_ = _get_target_ŷ(ŷ, y_, t) # this is to match the shape of the observations, which could have a time dimension | ||
| push!(tmp_obs, t => vec(y_),) | ||
| push!(tmp_pred, t => to_obs(vec(ŷ_,))) | ||
| end | ||
| out_pred = (; tmp_pred...) | ||
| out_obs = (; tmp_obs...) | ||
| return out_pred, out_obs | ||
| end | ||
|
Comment on lines
+119
to
+132
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be written more idiomatically and concisely using generator expressions to construct the named tuples directly, instead of using temporary arrays and a loop. This avoids intermediate allocations and improves readability. function to_obs_tuple(ŷ, y, target_names)
# first get observations, they could have a time dimension
out_pred = (; (t => to_obs(_get_target_ŷ(ŷ, _get_target_y(y, t), t)) for t in target_names)...)
out_obs = (; (t => _get_target_y(y, t) for t in target_names)...)
return out_pred, out_obs
end |
||
|
|
||
| function to_tuple(y::KeyedArray, target_names) | ||
| return (; (t => y(variable = t) for t in target_names)...) # observations are fixed, no Observables are needed! | ||
| end | ||
|
|
||
| function to_tuple(y::AbstractDimArray, target_names) | ||
| return (; (t => Array(y[variable = At(t)]) for t in target_names)...) # observations are fixed, no Observables are needed! | ||
| end | ||
|
|
||
| function monitor_to_obs(ŷ, monitor_names; cuts = (0.25, 0.5, 0.75)) | ||
| return (; | ||
| ( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The 1:1 reference line for the scatter plot has been commented out. This line is very useful for visualizing model performance. I recommend adding it back. The original implementation
sort(o_tr), sort(o_tr)was not ideal. A better approach is to use themnandmxvalues that are already calculated to draw the line across the full range of the axes.