Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/literate/tutorials/example_synthetic_lstm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ out_lstm = train(
training_loss = :nseLoss,
loss_types = [:nse],
sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0),
plotting = false,
plotting = true,
show_progress = false,
input_batchnorm = false,
array_type = pref_array_type,
Expand Down
21 changes: 15 additions & 6 deletions ext/EasyHybridMakie.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Makie.Colors
using DataFrames
import Makie
import EasyHybrid
import EasyHybrid: get_loss_value
import EasyHybrid: get_loss_value, _get_target_y, _get_target_ŷ
using Statistics

include("HybridTheme.jl")
Expand Down Expand Up @@ -346,7 +346,7 @@ function EasyHybrid.train_board(
Label(gd_tm[i][1, 1:2, Top()], "$(t)")

Makie.scatter!(ax_tr, p_tr_sub, o_tr_sub; color = :grey25, alpha = 0.6, markersize = 6)
Makie.lines!(ax_tr, sort(o_tr), sort(o_tr); color = :black, linestyle = :dash)
# Makie.lines!(ax_tr, sort(o_tr), sort(o_tr); color = :black, linestyle = :dash)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The 1:1 reference line for the scatter plot has been commented out. This line is very useful for visualizing model performance. I recommend adding it back. The original implementation sort(o_tr), sort(o_tr) was not ideal. A better approach is to use the mn and mx values that are already calculated to draw the line across the full range of the axes.

        Makie.lines!(ax_tr, [mn, mx], [mn, mx]; color = :black, linestyle = :dash)

# Validation scatter plot
ax_val = Makie.Axis(
gd_tm[i][1, 2]; aspect = 1, xlabel = "Predicted", ylabel = "",
Expand All @@ -371,7 +371,7 @@ function EasyHybrid.train_board(
o_val_sub = @lift(o_val[$val_idx])

Makie.scatter!(ax_val, p_val_sub, o_val_sub; color = :tomato, alpha = 0.6, markersize = 6)
Makie.lines!(ax_val, sort(o_val), sort(o_val); color = :black, linestyle = :dash)
# Makie.lines!(ax_val, sort(o_val), sort(o_val); color = :black, linestyle = :dash)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The 1:1 reference line for the validation scatter plot has been commented out. For consistency with the training plot and to better assess performance, I recommend adding it back using the same mn and mx range as the training plot.

        Makie.lines!(ax_val, [mn, mx], [mn, mx]; color = :black, linestyle = :dash)

end
Label(gd_tm[1][1:end, 0], "Observed", tellheight = false, rotation = pi / 2)
# Label(gd_tm[end+1,1:end], "Predicted")
Expand Down Expand Up @@ -433,7 +433,9 @@ function EasyHybrid.update_plotting_observables(
current_ŷ_val,
target_names,
epoch;
monitor_names
monitor_names,
y_train,
y_val
)

l_value = get_loss_value(l_train, training_loss, Symbol("$agg"))
Expand All @@ -447,8 +449,15 @@ function EasyHybrid.update_plotting_observables(

for t in target_names
# replace the array stored in the Observable:
train_preds[t][] = vec(getfield(current_ŷ_train, t))
val_preds[t][] = vec(getfield(current_ŷ_val, t))
current_y_train = _get_target_y(y_train, t)
current_y_train_hat = _get_target_ŷ(current_ŷ_train, current_y_train, t)

current_y_val = _get_target_y(y_val, t)
current_y_val_hat = _get_target_ŷ(current_ŷ_val, current_y_val, t)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is an inconsistent use of the function name here. The function is imported as _get_target_ŷ on line 9 and used with that name on line 453. However, here it is called as _get_target_ŷ. This will cause an UndefVarError. Please use the consistent name _get_target_ŷ.

        current_y_val_hat = _get_target_ŷ(current_ŷ_val, current_y_val, t)


train_preds[t][] = current_y_train_hat
val_preds[t][] = current_y_val_hat

# and notify Makie that it changed:
notify(train_preds[t])
notify(val_preds[t])
Expand Down
34 changes: 30 additions & 4 deletions src/utils/plotrecipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ function initialize_plotting_observables(init_ŷ_train, init_ŷ_val, y_train,
val_h_obs = to_obs([p_val])

# build NamedTuples of Observables for preds and obs
train_preds = to_obs_tuple(init_ŷ_train, target_names)
val_preds = to_obs_tuple(init_ŷ_val, target_names)
train_obs = map(Array, y_train)
val_obs = map(Array, y_val)
train_preds, train_obs = to_obs_tuple(init_ŷ_train, y_train, target_names)
val_preds, val_obs = to_obs_tuple(init_ŷ_val, y_val, target_names)

# train_preds = to_obs_tuple(init_ŷ_train, target_names)
# val_preds = to_obs_tuple(init_ŷ_val, target_names)
# train_obs = to_tuple(y_train, target_names)
# val_obs = to_tuple(y_val, target_names)
Comment on lines +65 to +68
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of commented-out code appears to be a remnant of a refactoring. It should be removed to improve code clarity.


# --- monitored parameters/state as Observables ---
train_monitor = !isempty(monitor_names) ? monitor_to_obs(init_ŷ_train, monitor_names) : nothing
Expand Down Expand Up @@ -113,6 +116,29 @@ function to_obs_tuple(y, target_names)
return (; (t => to_obs(vec(getfield(y, t))) for t in target_names)...)
end

function to_obs_tuple(ŷ, y, target_names)
# first get observations, they could have a time dimension
tmp_obs = []
tmp_pred = []
for t in target_names
y_ = _get_target_y(y, t)
ŷ_ = _get_target_ŷ(ŷ, y_, t) # this is to match the shape of the observations, which could have a time dimension
push!(tmp_obs, t => vec(y_),)
push!(tmp_pred, t => to_obs(vec(ŷ_,)))
end
out_pred = (; tmp_pred...)
out_obs = (; tmp_obs...)
return out_pred, out_obs
end
Comment on lines +119 to +132
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function can be written more idiomatically and concisely using generator expressions to construct the named tuples directly, instead of using temporary arrays and a loop. This avoids intermediate allocations and improves readability.

function to_obs_tuple(ŷ, y, target_names)
    # first get observations, they could have a time dimension
    out_pred = (; (t => to_obs(_get_target_ŷ(ŷ, _get_target_y(y, t), t)) for t in target_names)...)
    out_obs = (; (t => _get_target_y(y, t) for t in target_names)...)
    return out_pred, out_obs
end


function to_tuple(y::KeyedArray, target_names)
return (; (t => y(variable = t) for t in target_names)...) # observations are fixed, no Observables are needed!
end

function to_tuple(y::AbstractDimArray, target_names)
return (; (t => Array(y[variable = At(t)]) for t in target_names)...) # observations are fixed, no Observables are needed!
end

function monitor_to_obs(ŷ, monitor_names; cuts = (0.25, 0.5, 0.75))
return (;
(
Expand Down
Loading