diff --git a/activations_plus/sparsemax/sparsemax_func.py b/activations_plus/sparsemax/sparsemax_func.py index bed9c35..55ea66a 100644 --- a/activations_plus/sparsemax/sparsemax_func.py +++ b/activations_plus/sparsemax/sparsemax_func.py @@ -58,25 +58,7 @@ def forward(ctx: Any, x: torch.Tensor, dim: int = -1) -> torch.Tensor: # Translate by max for numerical stability x = x - x.max(dim=reduce_dim, keepdim=True).values.expand_as(x) - zs = x.sort(dim=reduce_dim, descending=True).values - d = x.size(reduce_dim) - range_th = torch.arange(1, d + 1, device=x.device, dtype=x.dtype) - shape = [1] * x.dim() - shape[reduce_dim] = d - range_th = range_th.view(*shape).expand_as(x) - - # Determine sparsity of projection - bound = 1 + range_th * zs - cumsum_zs = zs.cumsum(dim=reduce_dim) - is_gt = bound.gt(cumsum_zs).type(x.dtype) - k = (is_gt * range_th).max(dim=reduce_dim, keepdim=True).values - - # Compute threshold - zs_sparse = is_gt * zs - taus = (zs_sparse.sum(dim=reduce_dim, keepdim=True) - 1) / k - taus = taus.expand_as(x) - - output = torch.max(torch.zeros_like(x), x - taus) + output, ctx = SparsemaxFunction._threshold_and_support(ctx, x, reduce_dim) # Save context ctx.save_for_backward(output) @@ -116,12 +98,78 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: else: reduce_dim = ctx.dim - nonzeros = torch.ne(output, 0) - num_nonzeros = nonzeros.sum(dim=reduce_dim, keepdim=True) - sum_all = (grad_output * nonzeros).sum(dim=reduce_dim, keepdim=True) / num_nonzeros - grad_input = nonzeros * (grad_output - sum_all.expand_as(grad_output)) + grad_input = SparsemaxFunction._compute_gradient(ctx, grad_output, output, reduce_dim) if ctx.needs_reshaping: ctx, grad_input = unflatten_all_but_nth_dim(ctx, grad_input) return grad_input, None + + @staticmethod + def _threshold_and_support(ctx: Any, x: torch.Tensor, reduce_dim: int) -> tuple[torch.Tensor, Any]: + """Compute the threshold and support for the input tensor. + + Parameters + ---------- + ctx : Any + Context object for autograd. + x : torch.Tensor + Input tensor. + reduce_dim : int + Dimension along which to compute threshold/support. + + Returns + ------- + tuple[torch.Tensor, Any] + The output tensor after applying Sparsemax and the updated context. + + """ + zs = x.sort(dim=reduce_dim, descending=True).values + d = x.size(reduce_dim) + range_th = torch.arange(1, d + 1, device=x.device, dtype=x.dtype) + shape = [1] * x.dim() + shape[reduce_dim] = d + range_th = range_th.view(*shape).expand_as(x) + + # Determine sparsity of projection + bound = 1 + range_th * zs + cumsum_zs = zs.cumsum(dim=reduce_dim) + is_gt = bound.gt(cumsum_zs).type(x.dtype) + k = (is_gt * range_th).max(dim=reduce_dim, keepdim=True).values + + # Compute threshold + zs_sparse = is_gt * zs + taus = (zs_sparse.sum(dim=reduce_dim, keepdim=True) - 1) / k + taus = taus.expand_as(x) + + output = torch.max(torch.zeros_like(x), x - taus) + + return output, ctx + + @staticmethod + def _compute_gradient(ctx: Any, grad_output: torch.Tensor, output: torch.Tensor, reduce_dim: int) -> torch.Tensor: + """Compute the gradient for the backward pass. + + Parameters + ---------- + ctx : Any + Context object for autograd. + grad_output : torch.Tensor + Gradient of the loss with respect to the output. + output : torch.Tensor + Output tensor from the forward pass. + reduce_dim : int + Dimension along which to compute the gradient. + + Returns + ------- + torch.Tensor + The gradient with respect to the input. + + """ + nonzeros = torch.ne(output, 0) + num_nonzeros = nonzeros.sum(dim=reduce_dim, keepdim=True) + sum_all = (grad_output * nonzeros).sum(dim=reduce_dim, keepdim=True) / num_nonzeros + grad_input = nonzeros * (grad_output - sum_all.expand_as(grad_output)) + + return grad_input diff --git a/tests/sparsemax/test_sparsemax_func.py b/tests/sparsemax/test_sparsemax_func.py index e40d289..25ee09c 100644 --- a/tests/sparsemax/test_sparsemax_func.py +++ b/tests/sparsemax/test_sparsemax_func.py @@ -2,6 +2,7 @@ import torch from activations_plus.sparsemax import SparsemaxFunction +from activations_plus.sparsemax.utils import flatten_all_but_nth_dim, unflatten_all_but_nth_dim def test_sparsemax_forward_valid_input(): @@ -115,3 +116,42 @@ def test_sparsemax_backward_parametrized(x): assert torch.allclose(grad_sum, torch.zeros_like(grad_sum), atol=1e-5), ( "Gradients should sum to zero along the projection dimension" ) + + +def test_flatten_all_but_nth_dim(): + x = torch.randn(2, 3, 4, 5) + ctx = type('', (), {})() # Create an empty context object + ctx.dim = 1 + ctx, flattened_x = flatten_all_but_nth_dim(ctx, x) + assert flattened_x.shape == (3, 40), "Flattened shape is incorrect" + assert ctx.original_size == x.size(), "Original size not saved correctly in context" + + +def test_unflatten_all_but_nth_dim(): + x = torch.randn(3, 40) + ctx = type('', (), {})() # Create an empty context object + ctx.dim = 1 + ctx.original_size = (2, 3, 4, 5) + ctx, unflattened_x = unflatten_all_but_nth_dim(ctx, x) + assert unflattened_x.shape == (2, 3, 4, 5), "Unflattened shape is incorrect" + + +def test_threshold_and_support(): + x = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=torch.float32) + ctx = type('', (), {})() # Create an empty context object + ctx.dim = 1 + output, ctx = SparsemaxFunction._threshold_and_support(ctx, x, ctx.dim) + assert output is not None, "Output should not be None" + assert output.shape == x.shape, "Output shape should match input shape" + assert torch.all(output >= 0), "Output should have non-negative values" + + +def test_compute_gradient(): + grad_output = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=torch.float32) + output = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.1, 0.3]], dtype=torch.float32) + ctx = type('', (), {})() # Create an empty context object + ctx.dim = 1 + grad_input = SparsemaxFunction._compute_gradient(ctx, grad_output, output, ctx.dim) + assert grad_input is not None, "Gradient input should not be None" + assert grad_input.shape == grad_output.shape, "Gradient input shape should match grad output shape" + assert torch.all(torch.isfinite(grad_input)), "Gradient input should not contain NaN or Inf" diff --git a/tests/sparsemax/test_sparsemax_pb.py b/tests/sparsemax/test_sparsemax_pb.py index 4981132..d09ad9d 100644 --- a/tests/sparsemax/test_sparsemax_pb.py +++ b/tests/sparsemax/test_sparsemax_pb.py @@ -91,7 +91,6 @@ def test_sparsemax_backward_pb(random_data, dim): ), dim=st.integers(min_value=-1, max_value=0), ) -@pytest.mark.skip def test_sparsemax_v2_threshold_and_support(random_data, dim): x = torch.tensor(random_data, dtype=torch.double) tau, support_size = SparsemaxFunction._threshold_and_support(x, dim=dim) @@ -112,7 +111,6 @@ def test_sparsemax_v2_threshold_and_support(random_data, dim): elements=st.floats(min_value=-1000, max_value=1000, allow_nan=False, allow_infinity=False), ), ) -@pytest.mark.skip def test_compare_with_original(random_data): x = torch.tensor(random_data, dtype=torch.double, requires_grad=True) for dim in range(-1, x.dim()):