"""Sanity checks for weighted sampling on the PanNuke dataset.

Builds paired semantic/instance PanNuke datasets (with and without a
``MinTwoInstanceSampler``) and provides small check routines that
- dump augmented image/label pairs to TIFF for visual inspection,
- measure how many unique indices the gamma-weighted sampler draws,
- count per-class instance frequencies under different samplers,
- estimate sample uniqueness under a plain ``RandomSampler`` via hashing.
"""
import hashlib
import os
from functools import partial

import imageio.v3 as imageio
import numpy as np
import pandas as pd
import torch
from torch.utils.data import RandomSampler
from tqdm import tqdm

import micro_sam.training as sam_training
from torch_em.data import MinTwoInstanceSampler
from torch_em.data.datasets.histopathology.pannuke import get_pannuke_dataset
from patho_sam.training.util import get_sampler, build_transforms, geometric_transforms, photometric_transforms


PATH = os.environ.get("WORK")
data_path = os.path.join('/mnt/lustre-grete/usr/u12649/data/test', 'pannuke')
result_dir = "/mnt/lustre-grete/usr/u12649/data/test/sampling_results"

patch_shape = (1, 256, 256)
transforms = True

sampler = None
label_dtype = torch.float32

if transforms:
    # CellViT-style augmentations: geometric ones act on image + label,
    # photometric ones on the raw image only.
    geometric_seq, photometric_seq = build_transforms(patch_shape)
    transform = partial(geometric_transforms, seq=geometric_seq)
    raw_transform = partial(photometric_transforms, seq=photometric_seq)
else:
    transform = None
    raw_transform = sam_training.identity


# Deterministic datasets (no min-instance filtering), semantic and instance labels.
semantic_dataset = get_pannuke_dataset(
    path=data_path,
    patch_shape=patch_shape,
    ndim=2,
    folds=["fold_1", "fold_2"],
    custom_label_choice="semantic",
    sampler=sampler,
    label_dtype=label_dtype,
    raw_transform=raw_transform,
    download=True,
    deterministic_indices=True,
    transform=transform,
)

instance_dataset = get_pannuke_dataset(
    path=data_path,
    patch_shape=patch_shape,
    ndim=2,
    folds=["fold_1", "fold_2"],
    custom_label_choice="instances",
    sampler=sampler,
    label_dtype=label_dtype,
    raw_transform=raw_transform,
    download=True,
    deterministic_indices=True,
    transform=transform,
)

# Randomly sampled variants that require at least two instances per patch.
random_inst_dataset = get_pannuke_dataset(
    path=data_path,
    patch_shape=patch_shape,
    ndim=2,
    folds=["fold_1", "fold_2"],
    custom_label_choice="instances",
    sampler=MinTwoInstanceSampler(),
    label_dtype=label_dtype,
    raw_transform=raw_transform,
    download=True,
    transform=transform,
)

random_sem_dataset = get_pannuke_dataset(
    path=data_path,
    patch_shape=patch_shape,
    ndim=2,
    folds=["fold_1", "fold_2"],
    custom_label_choice="semantic",
    sampler=MinTwoInstanceSampler(),
    label_dtype=label_dtype,
    raw_transform=raw_transform,
    download=True,
    transform=transform,
)


def visualize_transformations():
    """Write the first 30 augmented image/label pairs to TIFF for visual inspection."""
    image_path = os.path.join(result_dir, "data_visualisation_transforms", "images")
    label_path = os.path.join(result_dir, "data_visualisation_transforms", "labels")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(label_path, exist_ok=True)

    for idx, (image, label) in enumerate(semantic_dataset, start=1):
        image = image.numpy()
        label = label.numpy()
        print(len(np.unique(image)))
        image = image.transpose(1, 2, 0)  # CHW -> HWC for image viewers.
        label = np.squeeze(label)

        imageio.imwrite(os.path.join(image_path, f"{idx:04}.tiff"), image)
        imageio.imwrite(os.path.join(label_path, f"{idx:04}.tiff"), label)
        if idx == 30:
            break


visualize_transformations()

result_dict = {
    "gamma": [],
    "average unique indices drawn": [],
    "covered samples in 10 iterations": [],
}


def check_sampled_indices():
    """For each gamma, count unique indices drawn by the weighted sampler over 10 epochs.

    Results are appended to the module-level ``result_dict`` and written to
    ``gamma_weighted_sampling.csv`` in ``result_dir``.
    """
    for gamma in np.linspace(0.5, 1, 6):
        # FIX: previously `gamma=1` was passed here, so the loop variable was
        # ignored and every CSV row measured the same sampler.
        sampler = get_sampler(instance_dataset, semantic_dataset, gamma=gamma, path=data_path, split="train")

        uniques_per_sampler = []
        uniques_all_samplers = []
        for _ in range(10):
            indices = list(sampler)
            uniques_per_sampler.append(len(np.unique(indices)))
            uniques_all_samplers.extend(np.unique(indices).tolist())
        print(f"Unique indices for gamma {gamma}: {uniques_per_sampler}, {len(sampler)}")
        print(f"Over 10 samplers unique indices: {len(np.unique(uniques_all_samplers))}")
        result_dict["gamma"].append(gamma)
        result_dict["average unique indices drawn"].append(np.mean(uniques_per_sampler))
        result_dict["covered samples in 10 iterations"].append(len(np.unique(uniques_all_samplers)))

    df = pd.DataFrame(result_dict)
    df.to_csv(os.path.join(result_dir, "gamma_weighted_sampling.csv"), index=False)


def check_sampled_instances():
    """Count per-class instance frequencies drawn by a plain RandomSampler (10 epochs)."""
    results_dict = {"gamma": [], "1": [], "2": [], "3": [], "4": [], "5": []}

    sampler = RandomSampler(semantic_dataset)
    for i in range(10):
        # For every drawn index count distinct instance ids per semantic class (1..5).
        result_array = np.array([
            [len(np.unique(random_inst_dataset[index][1][random_sem_dataset[index][1] == cell_type]))
             for cell_type in range(1, 6)]
            for index in tqdm(sampler)
        ])
        class_instance_counts = np.sum(result_array, axis=0).tolist()
        print(f"Random sampling: {class_instance_counts}")
        results_dict["gamma"].append(str(i))
        for cls in range(5):
            results_dict[str(cls + 1)].append(class_instance_counts[cls])

    df = pd.DataFrame(results_dict)
    print(df.head())
    df.to_csv(os.path.join(result_dir, "per_class_instances_gamma_random_mininstance.csv"), index=False)


def get_array_hash(array) -> str:
    """Return an uppercase 32-hex-char SHAKE-256 digest of the tensor's raw bytes."""
    return hashlib.shake_256(array.numpy().tobytes()).hexdigest(16).upper()


def check_random_indices():
    """Estimate how many unique samples a RandomSampler yields, via label hashing."""
    sampler = RandomSampler(random_sem_dataset)
    unique_indices = [
        len(set(get_array_hash(random_sem_dataset[idx][1]) for idx in sampler)) for _ in tqdm(range(10))
    ]
    print(unique_indices)
    df = pd.DataFrame({"unique_samples": unique_indices})
    df.to_csv(os.path.join(result_dir, "random_sampled_unique_indices.csv"), index=False)


# TODO: employ logic to check how many indices are actually successfully sampled with a RandomSampler,
# i.e. surpassing the MinTwoInstanceSampler (e.g. hash the images and use a set to check uniqueness).

# check_sampled_instances()
# check_sampled_indices()
# check_random_indices()
from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke - -ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external" +WORK = os.environ.get("WORK") +ROOT = os.path.join(WORK, "data", "eval_pannuke", "semantic_split") def evaluate_pannuke_semantic_segmentation(args): @@ -27,7 +27,7 @@ def evaluate_pannuke_semantic_segmentation(args): device = "cuda" if torch.cuda.is_available() else "cpu" # Get per class weights. - fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5") + fpath = os.path.join(WORK, "data", "eval_pannuke", "pannuke_fold_3.h5") fpath = "/" + fpath per_class_weights = extract_class_weights_for_pannuke(fpath=fpath) @@ -46,7 +46,7 @@ def evaluate_pannuke_semantic_segmentation(args): ) # Load the model weights - model_state = torch.load(checkpoint_path, map_location="cpu")["model_state"] + model_state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)["model_state"] unetr.load_state_dict(model_state) unetr.to(device) unetr.eval() @@ -107,7 +107,7 @@ def _get_average_results(sq_per_image, fname): print(results) # Get average results per method. - fname = checkpoint_path.rsplit("/")[-2] # Fetches the name of the style of training for semantic segmentation. + fname = checkpoint_path.rsplit("/")[-3] # Fetches the name of the style of training for semantic segmentation. 
_get_average_results(sq_per_image, f"pathosam_{fname}.csv") diff --git a/experiments/semantic_segmentation/generalists/submit_training.py b/experiments/semantic_segmentation/generalists/submit_training.py index dabbaa4..2e717bf 100644 --- a/experiments/semantic_segmentation/generalists/submit_training.py +++ b/experiments/semantic_segmentation/generalists/submit_training.py @@ -16,7 +16,7 @@ def submit_batch_script(script_name, decoder_only, decoder_from_pretrained, save #SBATCH --job-name=patho-sam source ~/.bashrc -micromamba activate super +mamba activate sam2 """ # Prepare the python scripts python_script = "python train_pannuke.py " diff --git a/experiments/semantic_segmentation/generalists/train_pannuke.py b/experiments/semantic_segmentation/generalists/train_pannuke.py index 3de709c..817db43 100644 --- a/experiments/semantic_segmentation/generalists/train_pannuke.py +++ b/experiments/semantic_segmentation/generalists/train_pannuke.py @@ -1,5 +1,6 @@ import os from collections import OrderedDict +from functools import partial import torch import torch.utils.data as data_util @@ -11,10 +12,12 @@ import micro_sam.training as sam_training from micro_sam.instance_segmentation import get_unetr -from patho_sam.training import SemanticInstanceTrainer +from patho_sam.training import SemanticInstanceTrainer, get_sampler +from patho_sam.training.util import (calculate_class_weights_for_loss_weighting, geometric_transforms, + photometric_transforms, build_transforms) -def get_dataloaders(patch_shape, data_path): +def get_dataloaders(patch_shape, data_path, weighted_sampling: bool, transforms: bool): """This returns the PanNuke data loaders implemented in `torch-em`. https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py It will automatically download the PanNuke data. @@ -25,11 +28,18 @@ def get_dataloaders(patch_shape, data_path): i.e. a tensor of the same spatial shape as `x`, with semantic labels for objects. 
Important: the ID 0 is reserved for background and ensure you have all semantic classes. """ - raw_transform = sam_training.identity - sampler = MinTwoInstanceSampler() + sampler = None if weighted_sampling else MinTwoInstanceSampler() label_dtype = torch.float32 - dataset = get_pannuke_dataset( + if transforms: + geometric_seq, photometric_seq = build_transforms(patch_shape) + transform = partial(geometric_transforms, seq=geometric_seq) + raw_transform = partial(photometric_transforms, seq=photometric_seq) + else: + transform = None + raw_transform = sam_training.identity + + sem_dataset = get_pannuke_dataset( path=data_path, patch_shape=patch_shape, ndim=2, @@ -39,15 +49,44 @@ def get_dataloaders(patch_shape, data_path): label_dtype=label_dtype, raw_transform=raw_transform, download=True, + deterministic_indices=True, + transform=transform, + ) + + inst_dataset = get_pannuke_dataset( + path=data_path, + patch_shape=patch_shape, + ndim=2, + folds=["fold_1", "fold_2"], + custom_label_choice="instances", + sampler=sampler, + label_dtype=label_dtype, + raw_transform=raw_transform, + download=True, + deterministic_indices=True, + transform=transform, ) # Create custom splits. generator = torch.Generator().manual_seed(42) - train_dataset, val_dataset = data_util.random_split(dataset, [0.8, 0.2], generator=generator) + inst_train_dataset, inst_val_dataset = data_util.random_split(inst_dataset, [0.8, 0.2], generator=generator) + sem_train_dataset, sem_val_dataset = data_util.random_split(sem_dataset, [0.8, 0.2], generator=generator) + + # Get the weighted samplers + if weighted_sampling: + train_sampler = get_sampler(inst_train_dataset, sem_train_dataset, gamma=1, path=data_path, split="train") + val_sampler = get_sampler(inst_val_dataset, sem_val_dataset, gamma=1, path=data_path, split="val") + shuffle = False + + else: + train_sampler = None + shuffle = True # Get the dataloaders. 
- train_loader = torch_em.get_data_loader(train_dataset, batch_size=8, shuffle=True, num_workers=16) - val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=16) + train_loader = torch_em.get_data_loader(sem_train_dataset, batch_size=8, num_workers=1, sampler=train_sampler, + shuffle=shuffle) + val_loader = torch_em.get_data_loader(sem_val_dataset, batch_size=1, num_workers=1, sampler=None, + shuffle=True) return train_loader, val_loader @@ -60,12 +99,18 @@ def train_pannuke_semantic_segmentation(args): num_classes = 6 # available classes are [0, 1, 2, 3, 4, 5] checkpoint_path = args.checkpoint_path device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint_name = f"{model_type}/pannuke_semantic" + checkpoint_name = f"{model_type}/pannuke_semantic"\ train_loader, val_loader = get_dataloaders( - patch_shape=(1, 512, 512), data_path=os.path.join(args.input_path, "pannuke") + patch_shape=(1, 256, 256), data_path=os.path.join(args.input_path, "pannuke"), + weighted_sampling=args.weighted_sampling, transforms=args.transforms ) + # Sampling and weighted loss + class_weights = calculate_class_weights_for_loss_weighting() if args.weighted_loss else None + checkpoint_name += "-weighted_loss" if args.weighted_loss else "" + checkpoint_name += "-weighted_sampling" if args.weighted_sampling else "" + checkpoint_name += "-transforms" if args.transforms else "" # Whether we opt for finetuning decoder only or finetune the entire backbone. if args.decoder_only: freeze = ["image_encoder", "prompt_encoder", "mask_decoder"] @@ -124,6 +169,7 @@ def train_pannuke_semantic_segmentation(args): convert_inputs=convert_inputs, num_classes=num_classes, dice_weight=0, + class_weights=class_weights, ) trainer.fit(iterations=int(args.iterations), overwrite_training=False) @@ -163,5 +209,17 @@ def main(args): "--decoder_from_pretrained", action="store_true", help="Whether to train the decoder from scratch, or train the pretrained decoder (i.e. 
used for AIS)." ) + parser.add_argument( + "--weighted_sampling", action="store_true", + help="Whether to use weighted sampling for class balancing" + ) + parser.add_argument( + "--weighted_loss", action="store_true", + help="Whether to use weighted loss for class balancing in the computation of the loss" + ) + parser.add_argument( + "--transforms", action='store_true', + help="Whether to use cellvit-like data augmentations" + ) args = parser.parse_args() main(args) diff --git a/patho_sam/training/__init__.py b/patho_sam/training/__init__.py index 66a47b4..d6a2f0e 100644 --- a/patho_sam/training/__init__.py +++ b/patho_sam/training/__init__.py @@ -1,2 +1,2 @@ -from .util import histopathology_identity, get_train_val_split +from .util import histopathology_identity, get_train_val_split, get_sampler from .semantic_trainer import SemanticInstanceTrainer diff --git a/patho_sam/training/util.py b/patho_sam/training/util.py index cca9b2a..b524bd7 100644 --- a/patho_sam/training/util.py +++ b/patho_sam/training/util.py @@ -1,11 +1,14 @@ -from typing import Tuple, List - +from typing import Tuple, List, Callable +import os +from tqdm import tqdm import numpy as np +import pandas as pd import torch import torch.utils.data as data_util from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb +import kornia.augmentation as K CLASS_MAP = { @@ -98,7 +101,7 @@ def remap_labels(y: np.ndarray, name: str) -> np.ndarray: def calculate_class_weights_for_loss_weighting( - foreground_class_weights: List[float] = [0.4702, 0.1797, 0.2229, 0.0159, 0.1113], + foreground_class_weights: List[float] = [0.507, 0.1082, 0.2284, 0.0038, 0.1526], ) -> List[float]: """Calculates the class weights for weighting the cross entropy loss. @@ -119,7 +122,7 @@ def calculate_class_weights_for_loss_weighting( foreground_class_weights = np.array(foreground_class_weights) # Define the range for integer weighting. 
def get_sampling_weights(instance_dataset, semantic_dataset, gamma: float, input_path, split):
    """Compute one positive sampling weight per training sample for class balancing.

    For every sample we count, per nucleus class (1-5), the number of distinct
    instance ids whose pixels carry that semantic class; rare classes then
    up-weight the samples that contain them.  ``gamma`` in [0, 1] interpolates
    between uniform (0) and fully class-balanced (1) weighting.  The slow
    per-sample counting sweep is cached as a CSV inside ``input_path`` so it
    only runs once per split.

    Args:
        instance_dataset: Dataset yielding ``(image, instance_label)`` pairs.
        semantic_dataset: Dataset yielding ``(image, semantic_label)`` pairs,
            aligned index-by-index with ``instance_dataset``.
        gamma: Balancing strength in [0, 1].
        input_path: Directory holding the count-cache CSV.
        split: Split name used in the cache filename (e.g. "train").

    Returns:
        torch.Tensor of shape ``(len(instance_dataset),)`` with one weight per sample.
    """
    cache_path = os.path.join(input_path, f"{split}_instance_sampling_weights.csv")

    if os.path.exists(cache_path):
        # Counts for this split were extracted before; reuse them.
        print(f"Sampling weights for the {split} set have already been extracted.")
        cell_type_presence = pd.read_csv(cache_path).to_numpy()
    else:
        # One row per sample, one column per nucleus class (1..5): the number
        # of distinct instance ids overlapping that semantic class.
        # NOTE(review): np.unique over the masked instance label may include
        # the background id 0 as well — confirm whether counts are off by one.
        paired = zip(instance_dataset, semantic_dataset)
        cell_type_presence = np.array([
            [len(np.unique(inst_label[sem_label == cls])) for cls in range(1, 6)]
            for (_, inst_label), (_, sem_label) in tqdm(paired, total=len(instance_dataset))
        ])
        pd.DataFrame(cell_type_presence).to_csv(cache_path, index=False)

    per_class_totals = np.sum(cell_type_presence, axis=0)
    grand_total = np.sum(per_class_totals)

    # Per-class factor: rare classes get factors > 1, common ones < 1.
    class_weights = grand_total / (gamma * per_class_totals + (1 - gamma) * grand_total)

    # Blend an unweighted term with the class-weighted sum, controlled by gamma.
    img_weight = (1 - gamma) * np.max(cell_type_presence, axis=-1) + gamma * np.sum(
        (cell_type_presence * class_weights), axis=1)

    # Samples with weight 0 get the smallest non-zero weight so they stay drawable.
    img_weight[np.where(img_weight == 0)] = np.min(img_weight[np.nonzero(img_weight)])

    return torch.Tensor(img_weight)


def get_sampler(instance_dataset, semantic_dataset, gamma, path, split) -> data_util.Sampler:
    """Build a with-replacement ``WeightedRandomSampler`` over the paired datasets."""
    pannuke_weights = get_sampling_weights(instance_dataset, semantic_dataset, gamma, path, split)
    return data_util.WeightedRandomSampler(
        weights=pannuke_weights,
        replacement=True,
        num_samples=len(instance_dataset),
    )


def geometric_transforms(x, y, seq):
    """Apply a joint geometric augmentation ``seq`` to image ``x`` and label ``y``.

    ``x`` is a uint8 CHW numpy image; it is scaled to [0, 1], batched, passed
    through ``seq`` together with the label, and converted back to uint8.
    NOTE(review): the returned label keeps a leading channel axis (1, H, W)
    rather than the input's (H, W) — confirm downstream code expects that.
    """
    img = torch.from_numpy(x.astype(np.float32) / 255.0).unsqueeze(0)
    lbl = torch.from_numpy(y).float().unsqueeze(0).unsqueeze(0)
    img, lbl = seq(img, lbl)
    img = (img.squeeze(0).numpy() * 255).clip(0, 255).astype(np.uint8)
    return img, lbl.squeeze(0).numpy()


def photometric_transforms(x, seq):
    """Apply a photometric augmentation ``seq`` to a uint8 CHW image, image-only."""
    img = torch.from_numpy(x.astype(np.float32) / 255.0).unsqueeze(0)
    img = seq(img)
    return (img.squeeze(0).numpy() * 255).clip(0, 255).astype(np.uint8)


def build_transforms(patch_shape) -> Tuple[Callable, Callable]:
    """Build CellViT-style kornia augmentation pipelines.

    Returns:
        A ``(geometric_seq, photometric_seq)`` pair: the first transforms
        image + mask jointly, the second transforms the image alone.
    """
    geometric_ops = (
        K.RandomRotation90(p=0.5, times=(1, 2)),
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomResizedCrop(size=(patch_shape[1], patch_shape[2]),
                            scale=(0.5, 0.5), p=0.15),
    )
    photometric_ops = (
        K.RandomGaussianBlur(kernel_size=(11, 11), sigma=(0.5, 2.0), p=0.2),
        K.RandomGaussianNoise(mean=0.0, std=0.05, p=0.25),
        K.ColorJitter(brightness=0.25, contrast=0.25,
                      saturation=0.1, hue=0.05, p=0.2),
    )

    geometric_seq = K.AugmentationSequential(*geometric_ops, data_keys=["input", "mask"])
    photometric_seq = K.AugmentationSequential(*photometric_ops, data_keys=["input"])
    return geometric_seq, photometric_seq
get_transforms(patch_shape) -> Tuple[Callable, Callable]: +# import kornia.augmentation as K +# transform_settings = { +# "randomrotate90": {"p": 0.5}, +# "horizontalflip": {"p": 0.5}, +# "verticalflip": {"p": 0.5}, +# "downscale": {"p": 0.15, "scale": 0.5}, # scale as fraction of original size +# "blur": {"p": 0.2, "kernel_size": 11}, # kernel_size must be odd +# "gaussnoise": {"p": 0.25, "std": 0.05}, +# "colorjitter": { +# "p": 0.2, +# "brightness": 0.25, +# "contrast": 0.25, +# "saturation": 0.1, +# "hue": 0.05 +# }, +# # "normalize": { +# # "mean": [0.5, 0.5, 0.5], +# # "std": [0.5, 0.5, 0.5] +# # } +# } + +# geometric_transforms_list = [] + +# photometric_transforms_list = [] + +# # Random 90° rotation +# geometric_transforms_list.append(K.RandomRotation90(p=transform_settings["randomrotate90"]["p"], times=(1, 2))) + +# # Horizontal flip +# geometric_transforms_list.append(K.RandomHorizontalFlip(p=transform_settings["horizontalflip"]["p"])) + +# # Vertical flip +# geometric_transforms_list.append(K.RandomVerticalFlip(p=transform_settings["verticalflip"]["p"])) + +# # Downscale (simulated via RandomResizedCrop) +# geometric_transforms_list.append( +# K.RandomResizedCrop( +# size=(patch_shape[1], patch_shape[2]), +# scale=(transform_settings["downscale"]["scale"], transform_settings["downscale"]["scale"]), +# p=transform_settings["downscale"]["p"] +# ) +# ) + +# # Blur +# photometric_transforms_list.append( +# K.RandomGaussianBlur( +# kernel_size=(transform_settings["blur"]["kernel_size"], transform_settings["blur"]["kernel_size"]), +# sigma=(0.1, 2.0), +# p=transform_settings["blur"]["p"] +# ) +# ) + +# # Gaussian noise +# photometric_transforms_list.append( +# K.RandomGaussianNoise( +# mean=0.0, +# std=transform_settings["gaussnoise"]["std"], +# p=transform_settings["gaussnoise"]["p"] +# ) +# ) + +# # Color jitter +# photometric_transforms_list.append( +# K.ColorJitter( +# brightness=transform_settings["colorjitter"]["brightness"], +# 
def get_sampling_weights_cellvit(dataset, gamma: float):
    """Class-balancing sample weights, modified from CellViT (Hörst et al. 2024).

    Uses *binary* per-class presence (is class c present in the sample at all)
    rather than instance counts.

    Args:
        dataset: Iterable of ``(image, label)`` pairs with integer class labels.
        gamma: Balancing strength in [0, 1]; 0 = uniform, 1 = fully balanced.

    Returns:
        torch.Tensor with one positive weight per sample.
    """
    # Rows = samples, columns = binary presence of nucleus classes 1..5.
    cell_type_presence = np.array(
        [[int(cell_type in np.unique(label)) for cell_type in range(1, 6)] for _, label in dataset])

    # Number of samples each class occurs in.
    binary_weight_factors = np.sum(cell_type_presence, axis=0)
    k = np.sum(binary_weight_factors)

    # Per-class factor: rare classes > 1, frequent classes < 1.
    # NOTE(review): a class absent from the whole dataset divides by zero at
    # gamma == 1 — confirm all five classes occur in practice.
    weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k)

    img_weight = (1 - gamma) * np.max(cell_type_presence, axis=-1) + gamma * np.sum(
        cell_type_presence * weight_vector, axis=-1
    )

    # Background-only samples would get weight 0; give them the smallest
    # non-zero weight so every sample stays drawable.
    img_weight[np.where(img_weight == 0)] = np.min(img_weight[np.nonzero(img_weight)])

    return torch.Tensor(img_weight)


class DeterministicSubset(torch.utils.data.Dataset):
    """Wrap a ``torch.utils.data.Subset`` so items are fetched without augmentations.

    Temporarily sets the underlying dataset's ``transform`` to ``None`` while
    fetching, then restores it.
    """

    def __init__(self, subset):
        self.subset = subset
        self.base_dataset = subset.dataset
        self.indices = subset.indices

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        orig_transform = self.base_dataset.transform
        self.base_dataset.transform = None
        try:
            return self.base_dataset[real_idx]
        finally:
            # FIX: restore even if indexing raises — previously a failed fetch
            # left the shared dataset with transform == None.
            self.base_dataset.transform = orig_transform


class DeterministicDataset(torch.utils.data.Dataset):
    """Fetch items from a (possibly nested, Subset-wrapped) dataset without transforms."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Unwrap nested Subset / wrapper layers, remapping the index along the way.
        ds = self.dataset
        while hasattr(ds, "dataset"):
            if isinstance(ds, torch.utils.data.Subset):
                idx = ds.indices[idx]
            ds = ds.dataset

        if hasattr(ds, "transform"):
            orig_transform = ds.transform
            ds.transform = None
            try:
                return ds[idx]
            finally:
                # FIX: restore even if indexing raises.
                ds.transform = orig_transform

        return ds[idx]