Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ model:
d_ff: 1024
n_layers: 3
dropout: 0.1
k_periods: 3
k_periods: 2
min_period_threshold: 7 # 최소 주기 하한선
kernel_set:
- [3, 3]
Expand Down
33 changes: 24 additions & 9 deletions src/timesnet_forecast/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(
else:
self.M = valid_mask.astype(np.float32)
self.T, self.N = self.X.shape
if self.N <= 0:
raise ValueError("wide_values must contain at least one series column")
self.L = int(input_len)
if mode == "direct":
self.H = int(pred_len)
Expand Down Expand Up @@ -157,32 +159,45 @@ def __init__(
else:
self.series_ids = None

self._windows_per_series = int(len(self.idxs))

def __len__(self) -> int:
return int(len(self.idxs))
return int(self._windows_per_series * self.N)

def __getitem__(self, idx: int) -> tuple[object, ...]:
s = int(self.idxs[idx])
if self._windows_per_series <= 0:
raise IndexError("SlidingWindowDataset is empty")
window_idx = int(idx // self.N)
series_idx = int(idx % self.N)
if window_idx >= self._windows_per_series:
raise IndexError("index out of range for sliding windows")
s = int(self.idxs[window_idx])
if self.time_shift > 0:
delta = np.random.randint(-self.time_shift, self.time_shift + 1)
s = int(np.clip(s + delta, 0, self.T - self.L - self.H))
e = s + self.L
x_tensor = self._X_tensor[s:e, :].clone()
x_slice = self._X_tensor[s:e, series_idx]
if self.add_noise_std > 0:
x_tensor = x_slice.clone().unsqueeze(-1)
noise = torch.randn_like(x_tensor) * self.add_noise_std
x_tensor = x_tensor + noise
y_tensor = self._X_tensor[e : e + self.H, :].clone()
mask_tensor = self._M_tensor[e : e + self.H, :].clone()
else:
x_tensor = x_slice.unsqueeze(-1)
y_tensor = self._X_tensor[e : e + self.H, series_idx].unsqueeze(-1)
mask_tensor = self._M_tensor[e : e + self.H, series_idx].unsqueeze(-1)
if self.time_marks is not None:
x_mark = self.time_marks[s:e, :].clone()
y_mark = self.time_marks[e : e + self.H, :].clone()
x_mark = self.time_marks[s:e, :]
y_mark = self.time_marks[e : e + self.H, :]
else:
x_mark = self._empty_time_mark
y_mark = self._empty_time_mark
items: list[object] = [x_tensor, y_tensor, mask_tensor, x_mark, y_mark]
if self.series_static is not None:
items.append(self.series_static)
static_slice = self.series_static[series_idx : series_idx + 1, :]
items.append(static_slice)
if self.series_ids is not None:
items.append(self.series_ids)
id_slice = self.series_ids[series_idx : series_idx + 1]
items.append(id_slice)
return tuple(items)

@staticmethod
Expand Down
48 changes: 32 additions & 16 deletions src/timesnet_forecast/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@
import torch


def negative_binomial_mask(
y: torch.Tensor,
rate: torch.Tensor,
dispersion: torch.Tensor,
mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute a boolean mask for valid NB likelihood elements."""

finite_mask = torch.isfinite(y) & torch.isfinite(rate) & torch.isfinite(dispersion)
if mask is not None:
mask_bool = mask.to(dtype=torch.bool)
if mask_bool.shape != finite_mask.shape:
mask_bool = mask_bool.expand_as(finite_mask)
finite_mask = finite_mask & mask_bool
return finite_mask


def negative_binomial_nll(
y: torch.Tensor,
rate: torch.Tensor,
Expand All @@ -13,26 +30,25 @@ def negative_binomial_nll(
"""Negative binomial negative log-likelihood averaged over valid elements."""

dtype = torch.float32
y = y.to(dtype)
y = torch.clamp(y.to(dtype), min=0.0)
rate = rate.to(dtype)
dispersion = dispersion.to(dtype)

alpha = torch.clamp(dispersion, min=eps)
mu = torch.clamp(rate, min=eps)
r = 1.0 / alpha
log_p = torch.log(r) - torch.log(r + mu)
log1m_p = torch.log(mu) - torch.log(r + mu)
log_prob = (
torch.lgamma(y + r)
- torch.lgamma(r)
log1p_alpha_mu = torch.log1p(alpha * mu)
log_alpha = torch.log(alpha)
log_mu = torch.log(mu)
inv_alpha = torch.reciprocal(alpha)
ll = (
torch.lgamma(y + inv_alpha)
- torch.lgamma(inv_alpha)
- torch.lgamma(y + 1.0)
+ r * log_p
+ y * log1m_p
+ inv_alpha * (-log1p_alpha_mu)
+ y * (log_alpha + log_mu - log1p_alpha_mu)
Comment on lines 37 to +48
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] NB likelihood refactor is algebraically equivalent — the log(alpha) terms cancel

The refactored negative_binomial_nll does not drop a term. With r = 1/alpha, the previous r * log_p term equals (1/alpha) * (log r − log(r + mu)) = (1/alpha) * (−log(alpha) − (log1p(alpha·mu) − log(alpha))) = −(1/alpha) * log1p(alpha·mu): the log(alpha) contributions cancel exactly, so no separate −(1/alpha)·log(alpha) term belongs in the expression. Likewise y * log1m_p = y * (log(alpha) + log(mu) − log1p(alpha·mu)). The new formulation therefore matches the standard negative-binomial log-likelihood for all dispersion values, and both the loss value and the gradient with respect to dispersion are unchanged from the previous derivation.

Useful? React with 👍 / 👎.

)
if mask is not None:
mask = mask.to(dtype)
log_prob = log_prob * mask
denom = torch.clamp(mask.sum(), min=1.0)
else:
denom = log_prob.numel()
return -(log_prob.sum() / denom)

valid_mask = negative_binomial_mask(y, mu, alpha, mask)
weight = valid_mask.to(dtype)
denom = torch.clamp(weight.sum(), min=1.0)
return -(ll * weight).sum() / denom
46 changes: 30 additions & 16 deletions src/timesnet_forecast/models/timesnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,34 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
return empty_idx, empty_amp.to(dtype)

freq_indices = torch.arange(amp_mean.numel(), device=device, dtype=dtype)
tie_break = freq_indices * torch.finfo(dtype).eps
scores = amp_mean - tie_break
freq_indices = torch.arange(amp_mean.numel(), device=device, dtype=torch.long)
log_indices = torch.log1p(freq_indices.to(torch.float32))
scores = amp_mean - 1e-8 * log_indices.to(dtype)
_, indices = torch.topk(scores, k=k, largest=True)
safe_indices = indices.to(device=device, dtype=torch.long).clamp_min(1)
sample_values = amp_samples.gather(
1, safe_indices.view(1, -1).expand(B, -1)
)

L_t = torch.tensor(L, dtype=torch.long, device=device)
upper_bound = min(self.pmax, max(L - 1, self.min_period_threshold))
if upper_bound < self.min_period_threshold:
empty_idx = torch.zeros(0, dtype=torch.long, device=device)
empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
return empty_idx, empty_amp.to(dtype)

periods = (L_t + safe_indices - 1) // safe_indices
periods = torch.clamp(
periods,
min=self.min_period_threshold,
max=self.pmax,
)
periods = torch.clamp(periods, min=self.min_period_threshold, max=upper_bound)

cycles = (L_t + periods - 1) // periods
valid_mask = cycles >= 2
if not torch.any(valid_mask):
empty_idx = torch.zeros(0, dtype=torch.long, device=device)
empty_amp = torch.zeros(B, 0, dtype=amp_samples.dtype, device=device)
return empty_idx, empty_amp.to(dtype)

periods = periods[valid_mask]
sample_values = sample_values[:, valid_mask]

return periods, sample_values.to(dtype)

Expand Down Expand Up @@ -280,6 +292,7 @@ def __init__(
# ``period_selector`` is injected from ``TimesNet`` after instantiation to
# avoid registering the shared selector multiple times.
self.period_selector: FFTPeriodSelector | None = None
self._period_calls: int = 0

def _build_layers(self, channels: int, device: torch.device, dtype: torch.dtype) -> None:
if channels <= 0:
Expand Down Expand Up @@ -322,6 +335,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.period_selector is None:
raise RuntimeError("TimesBlock.period_selector has not been set")

self._period_calls = getattr(self, "_period_calls", 0) + 1
if self.inception is None:
if self._configured_d_model is not None and x.size(-1) != self._configured_d_model:
raise ValueError(
Expand Down Expand Up @@ -1019,14 +1033,14 @@ def _ensure_embedding(
)
self.pre_embedding_dropout = self.pre_embedding_dropout.to(device=x.device)

if (
isinstance(self.min_sigma_vector, torch.Tensor)
and self.min_sigma_vector.numel() > 0
and self.min_sigma_vector.shape[-1] != c_in
):
raise ValueError(
"min_sigma_vector length does not match number of series"
)
if isinstance(self.min_sigma_vector, torch.Tensor) and self.min_sigma_vector.numel() > 0:
current = int(self.min_sigma_vector.shape[-1])
if current < c_in:
raise ValueError(
"min_sigma_vector length does not match number of series"
)
if current != c_in:
self.min_sigma_vector = self.min_sigma_vector[..., :c_in]

if self.embedding_time_features is not None and self.embedding_time_features != time_dim:
raise ValueError("Temporal feature dimension changed between calls")
Expand Down
Loading