Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions references/wifi_densepose_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,57 @@
from typing import Dict, List, Tuple, Optional
from collections import OrderedDict

class RestrictedUnpickler(pickle.Unpickler):
"""
Restricted unpickler that only allows safe PyTorch and NumPy types.
Prevents arbitrary code execution during model loading.
Based on Trail of Bits recommendations for safe deserialization.
"""
def find_class(self, module, name):
# Only allow specific safe modules and classes
safe_modules = {
'torch', 'torch.nn', 'torch.optim', 'torch._utils',
'torch.storage', 'numpy', 'numpy.core.multiarray',
'collections', '__builtin__', 'builtins'
}

# Check if module is in safe list or is a submodule of a safe module
if module in safe_modules or any(module.startswith(safe + '.') for safe in safe_modules):
return super().find_class(module, name)

# Reject everything else to prevent arbitrary code execution
raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden for security reasons")

def safe_torch_load(path, map_location=None):
"""
Safely load a PyTorch checkpoint using restricted unpickling.
This prevents arbitrary code execution vulnerabilities in pickle.

Args:
path: Path to the checkpoint file
map_location: Device to map the loaded tensors to

Returns:
Loaded checkpoint dictionary
"""
with open(path, 'rb') as f:
# Use restricted unpickler instead of default pickle
return RestrictedUnpickler(f).load()

def safe_torch_save(obj, path):
"""
Safely save a PyTorch checkpoint.
While this still uses pickle internally, combined with safe_torch_load()
it provides a complete safe serialization pipeline.

Args:
obj: Object to save (typically a dict with state_dicts)
path: Path where to save the checkpoint
"""
# Save using standard pickle, but ensure loading is done safely
with open(path, 'wb') as f:
pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

class CSIPhaseProcessor:
"""
Processes raw CSI phase data through unwrapping, filtering, and linear fitting
Expand Down Expand Up @@ -435,13 +486,16 @@ def train_step(self, amplitude_data, phase_data, targets):
return loss.item(), loss_dict

def save_model(self, path):
torch.save({
# Use safe saving function to avoid torch.save which relies on pickle
safe_torch_save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}, path)

def load_model(self, path):
checkpoint = torch.load(path)
# Use safe loading with RestrictedUnpickler to prevent arbitrary code execution
# This whitelists only safe PyTorch/NumPy classes and blocks malicious code
checkpoint = safe_torch_load(path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

Expand Down