Skip to content

Add nn loss adapter support in v2 BaseModel with multi-target handling#2238

Open
adityashinde0 wants to merge 8 commits intosktime:mainfrom
adityashinde0:nn-loss-adapter-v2
Open

Add nn loss adapter support in v2 BaseModel with multi-target handling#2238
adityashinde0 wants to merge 8 commits intosktime:mainfrom
adityashinde0:nn-loss-adapter-v2

Conversation

@adityashinde0
Copy link
Copy Markdown

@adityashinde0 adityashinde0 commented Mar 27, 2026

LLM generated content, by GPT-5

Reference Issues/PRs

Fixes #1970.

What does this implement/fix? Explain your changes.

This PR adds support for using torch.nn losses in ptf-v2 while keeping the existing model output contract unchanged.

Changes made:

  • Added an adapter wrapper class:
    pytorch_forecasting/models/base/_loss_adapter_v2.py
    • NNLossAdapter handles shape conversion for nn losses internally.
    • Supports common same-shape losses (e.g., MSELoss, L1Loss) and class-index style losses (e.g., CrossEntropyLoss, NLLLoss).
    • Provides to_prediction and to_quantiles fallbacks to stay compatible with v2 model APIs.
  • Updated v2 base model loss initialization in
    pytorch_forecasting/models/base/_base_model_v2.py
    • Allows users to pass nn.Module losses directly.
    • Under the hood, nn losses are automatically wrapped with NNLossAdapter.
    • Existing Metric/MultiLoss usage remains supported.
  • Updated v2 loss computation path in _compute_loss
    • Handles single-target unchanged.
    • Handles multi-target:
      • Uses MultiLoss when provided.
      • For a single wrapped nn loss, applies loss per target tensor and aggregates (mean), avoiding passing list directly to nn losses.

This addresses the issue where nn losses could not handle list targets and where prediction/target shape expectations differ across loss types.

What should a reviewer concentrate their feedback on?

  • NNLossAdapter behavior and shape handling logic for:
    • same-shape losses
    • class-index losses
  • Backward compatibility for existing Metric/MultiLoss flow in v2
  • Multi-target behavior in _compute_loss (especially tensor slicing and aggregation strategy)
  • Whether to_prediction / to_quantiles fallbacks are appropriate for nn-loss-backed models

Did you add any tests for the change?

I validated via local targeted model test runs:

  • tests/test_models/test_tft_v2.py
  • tests/test_models/test_timexer_v2.py

If maintainers prefer, I can add dedicated unit tests specifically for NNLossAdapter and multi-target nn-loss behavior in follow-up.

Any other comments?

  • There is also a change in pytorch_forecasting/models/samformer/_samformer_v2.py in my working tree. Please let me know if you prefer this in a separate PR to keep [ENH] Add support for nn losses to ptf-v2 #1970 strictly scoped.
  • Happy to revise aggregation policy (mean vs weighted) for multi-target nn losses if project conventions require a different strategy.

PR checklist

  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG].
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@phoeenniixx phoeenniixx added the AI overuse suspected Overuse of AI without understanding what is happening in the code label Mar 27, 2026
@adityashinde0
Copy link
Copy Markdown
Author

hello , Anyone can help , this is my first PR ...

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 28, 2026

Codecov Report

❌ Patch coverage is 33.33333% with 46 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@17c51ba). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...ytorch_forecasting/models/base/_loss_adapter_v2.py 27.27% 32 Missing ⚠️
pytorch_forecasting/models/base/_base_model_v2.py 44.00% 14 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2238   +/-   ##
=======================================
  Coverage        ?   86.32%           
=======================================
  Files           ?      167           
  Lines           ?     9819           
  Branches        ?        0           
=======================================
  Hits            ?     8476           
  Misses          ?     1343           
  Partials        ?        0           
Flag Coverage Δ
cpu 86.32% <33.33%> (?)
pytest 86.32% <33.33%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@adityashinde0
Copy link
Copy Markdown
Author

I think i am done any improvement needed @phoeenniixx @PranavBhatP

Copy link
Copy Markdown
Member

@phoeenniixx phoeenniixx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of files that should not be changed for this PR. Please revert those changes and when you are using AI, please use it carefully, and make sure that you are able to understand all the changes it is making.

@adityashinde0
Copy link
Copy Markdown
Author

There are a lot of files that should not be changed for this PR. Please revert those changes and when you are using AI, please use it carefully, and make sure that you are able to understand all the changes it is making.

Thanks !!!
for guiding me this is my first PR and I am try to understand all process , Any reference can you provide me to gain more knowledge in contribution and issue resolving by understanding code ..

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

AI overuse suspected Overuse of AI without understanding what is happening in the code

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ENH] Add support for nn losses to ptf-v2

2 participants