From 00f4591c97580ccbc5d5ac363350dd7fd09d6b15 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 13 Feb 2026 09:32:29 +0100 Subject: [PATCH 1/8] first working commit --- pina/_src/condition/condition_base.py | 16 +- pina/_src/core/trainer.py | 33 +- pina/_src/data/aggregator.py | 58 +++ pina/_src/data/creator.py | 178 ++++++++ pina/_src/data/data_module.py | 620 ++++++-------------------- pina/_src/data/dummy_dataloader.py | 62 +++ pina/_src/problem/abstract_problem.py | 74 +-- 7 files changed, 517 insertions(+), 524 deletions(-) create mode 100644 pina/_src/data/aggregator.py create mode 100644 pina/_src/data/creator.py create mode 100644 pina/_src/data/dummy_dataloader.py diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py index b8290d717..44a8af2b7 100644 --- a/pina/_src/condition/condition_base.py +++ b/pina/_src/condition/condition_base.py @@ -9,6 +9,7 @@ from pina._src.condition.condition_interface import ConditionInterface from pina._src.core.graph import LabelBatch from pina._src.core.label_tensor import LabelTensor +from pina._src.data.dummy_dataloader import DummyDataloader class ConditionBase(ConditionInterface): @@ -85,7 +86,8 @@ def automatic_batching_collate_fn(cls, batch): if not batch: return {} instance_class = batch[0].__class__ - return instance_class.create_batch(batch) + batch = instance_class.create_batch(batch) + return batch @staticmethod def collate_fn(batch, condition): @@ -103,7 +105,11 @@ def collate_fn(batch, condition): return data def create_dataloader( - self, dataset, batch_size, shuffle, automatic_batching + self, + dataset, + batch_size, + automatic_batching, + **kwargs, ): """ Create a DataLoader for the condition. @@ -114,14 +120,14 @@ def create_dataloader( :rtype: torch.utils.data.DataLoader """ if batch_size == len(dataset): - pass # will be updated in the near future + return DummyDataloader(dataset) return DataLoader( dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, collate_fn=( partial(self.collate_fn, condition=self) if not automatic_batching else self.automatic_batching_collate_fn ), + batch_size=batch_size, + **kwargs, ) diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py index 7500be537..377b42fac 100644 --- a/pina/_src/core/trainer.py +++ b/pina/_src/core/trainer.py @@ -36,7 +36,7 @@ def __init__( test_size=0.0, val_size=0.0, compile=None, - repeat=None, + batching_mode="common_batch_size", automatic_batching=None, num_workers=None, pin_memory=None, @@ -61,9 +61,9 @@ def __init__( :param bool compile: If ``True``, the model is compiled before training. Default is ``False``. For Windows users, it is always disabled. Not supported for python version greater or equal than 3.14. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. For further details, see the - :class:`~pina.data.data_module.PinaDataModule` class. Default is + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. Default is ``"common_batch_size"``. ``False``. :param bool automatic_batching: If ``True``, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset @@ -87,7 +87,7 @@ def __init__( train_size=train_size, test_size=test_size, val_size=val_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, compile=compile, ) @@ -127,8 +127,6 @@ def __init__( UserWarning, ) - repeat = repeat if repeat is not None else False - automatic_batching = ( automatic_batching if automatic_batching is not None else False ) @@ -144,7 +142,7 @@ def __init__( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, @@ -182,7 +180,7 @@ def _create_datamodule( test_size, val_size, batch_size, - repeat, + batching_mode, automatic_batching, pin_memory, num_workers, @@ -201,8 +199,9 @@ def _create_datamodule( :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data @@ -232,7 +231,7 @@ def _create_datamodule( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, @@ -284,7 +283,7 @@ def _check_input_consistency( train_size, test_size, val_size, - repeat, + batching_mode, automatic_batching, compile, ): @@ -298,8 +297,9 @@ def _check_input_consistency( test dataset. :param float val_size: The percentage of elements to include in the validation dataset. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool compile: If ``True``, the model is compiled before training. @@ -309,8 +309,7 @@ def _check_input_consistency( check_consistency(train_size, float) check_consistency(test_size, float) check_consistency(val_size, float) - if repeat is not None: - check_consistency(repeat, bool) + check_consistency(batching_mode, str) if automatic_batching is not None: check_consistency(automatic_batching, bool) if compile is not None: diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py new file mode 100644 index 000000000..c788132c2 --- /dev/null +++ b/pina/_src/data/aggregator.py @@ -0,0 +1,58 @@ +""" +Aggregator for multiple dataloaders. +""" + + +class _Aggregator: + """ + The class :class:`_Aggregator` is responsible for aggregating multiple + dataloaders into a single iterable object. It supports different batching + modes to accommodate various training requirements. + """ + + def __init__(self, dataloaders, batching_mode): + """ + Initialization of the :class:`_Aggregator` class. + + :param dataloaders: A dictionary mapping condition names to their + respective dataloaders. + :type dataloaders: dict[str, DataLoader] + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + """ + self.dataloaders = dataloaders + self.batching_mode = batching_mode + + def __len__(self): + """ + Return the length of the aggregated dataloader. + + :return: The length of the aggregated dataloader. + :rtype: int + """ + return max(len(dl) for dl in self.dataloaders.values()) + + def __iter__(self): + """ + Return an iterator over the aggregated dataloader. + + :return: An iterator over the aggregated dataloader. + :rtype: iterator + """ + if self.batching_mode == "separate_conditions": + for name, dl in self.dataloaders.items(): + for batch in dl: + yield {name: batch} + return + iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} + for _ in range(len(self)): + batch = {} + for name, it in iterators.items(): + try: + batch[name] = next(it) + except StopIteration: + iterators[name] = iter(self.dataloaders[name]) + batch[name] = next(iterators[name]) + yield batch diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py new file mode 100644 index 000000000..b0e6d37c1 --- /dev/null +++ b/pina/_src/data/creator.py @@ -0,0 +1,178 @@ +""" +Module defining the Creator class, responsible for creating dataloaders +for multiple conditions with various batching strategies. +""" + +import torch +from torch.utils.data import RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + + +class _Creator: + """ + The class :class:`_Creator` is responsible for creating dataloaders for + multiple conditions based on specified batching strategies. It supports + different batching modes to accommodate various training requirements. + """ + + def __init__( + self, + batching_mode, + batch_size, + shuffle, + automatic_batching, + num_workers, + pin_memory, + conditions, + ): + """ + Initialization of the :class:`_Creator` class. + + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + :param batch_size: The batch size to use for dataloaders. If + ``batching_mode`` is ``"proportional"``, this represents the total + batch size across all conditions. + :type batch_size: int | None + :param shuffle: Whether to shuffle the data in the dataloaders. + :type shuffle: bool + :param automatic_batching: Whether to use automatic batching in the + dataloaders. + :type automatic_batching: bool + :param num_workers: The number of worker processes to use for data + loading. + :type num_workers: int + :param pin_memory: Whether to pin memory in the dataloaders. + :type pin_memory: bool + :param conditions: A dictionary mapping condition names to their + respective condition objects. + :type conditions: dict[str, Condition] + """ + self.batching_mode = batching_mode + self.batch_size = batch_size + self.shuffle = shuffle + self.automatic_batching = automatic_batching + self.num_workers = num_workers + self.pin_memory = pin_memory + self.conditions = conditions + + def _define_sampler(self, dataset, shuffle): + if torch.distributed.is_initialized(): + return DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + return RandomSampler(dataset) + return SequentialSampler(dataset) + + def _compute_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition based on the specified + batching mode. + + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. + :rtype: dict[str, int] + """ + batch_sizes = {} + if self.batching_mode == "common_batch_size": + for name in datasets.keys(): + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + if self.batching_mode == "proportional": + return self._compute_proportional_batch_sizes(datasets) + if self.batching_mode == "separate_conditions": + for name in datasets.keys(): + condition = self.conditions[name] + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + raise ValueError(f"Unknown batching mode: {self.batching_mode}") + + def _compute_proportional_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition proportionally based on the + size of their datasets. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. + :rtype: dict[str, int] + """ + # Compute number of elements per dataset + elements_per_dataset = { + dataset_name: len(dataset) + for dataset_name, dataset in datasets.items() + } + # Compute the total number of elements + total_elements = sum(el for el in elements_per_dataset.values()) + # Compute the portion of each dataset + portion_per_dataset = { + name: el / total_elements + for name, el in elements_per_dataset.items() + } + # Compute batch size per dataset. Ensure at least 1 element per + # dataset. + batch_size_per_dataset = { + name: max(1, int(portion * self.batch_size)) + for name, portion in portion_per_dataset.items() + } + # Adjust batch sizes to match the specified total batch size + tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) + if self.batch_size > tot_el_per_batch: + difference = self.batch_size - tot_el_per_batch + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] += 1 + difference -= 1 + if self.batch_size < tot_el_per_batch: + difference = tot_el_per_batch - self.batch_size + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] -= 1 + difference -= 1 + return batch_size_per_dataset + + def __call__(self, datasets): + """ + Create dataloaders for each condition based on the specified batching + mode. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their created + dataloaders. + :rtype: dict[str, DataLoader] + """ + # Compute batch sizes per condition based on batching_mode + batch_sizes = self._compute_batch_sizes(datasets) + dataloaders = {} + for name, dataset in datasets.items(): + dataloaders[name] = self.conditions[name].create_dataloader( + dataset=dataset, + batch_size=batch_sizes[name], + automatic_batching=self.automatic_batching, + sampler=self._define_sampler(dataset, self.shuffle), + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return dataloaders diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index f45236f0f..b39596eaf 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -12,227 +12,52 @@ from torch.utils.data.distributed import DistributedSampler from pina._src.core.label_tensor import LabelTensor from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset +from pina._src.data.creator import _Creator +from pina._src.data.aggregator import _Aggregator -class DummyDataloader: - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. - - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx) - else: - self.dataset = dataset.get_all_data() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset - - -class Collator: +class _ConditionSubset: """ - This callable class is used to collate the data points fetched from the - dataset. The collation is performed based on the type of dataset used and - on the batching strategy. + This class extends the :class:`torch.utils.data.Subset` class, allowing to + fetch the data from the dataset based on a list of indices. """ - def __init__( - self, max_conditions_lengths, automatic_batching, dataset=None - ): - """ - Initialize the object, setting the collate function based on whether - automatic batching is enabled or not. - - :param dict max_conditions_lengths: ``dict`` containing the maximum - number of data points to consider in a single batch for - each condition. - :param bool automatic_batching: Whether automatic PyTorch batching is - enabled or not. For more information, see the - :class:`~pina.data.data_module.PinaDataModule` class. - :param PinaDataset dataset: The dataset where the data is stored. - """ - - self.max_conditions_lengths = max_conditions_lengths - # Set the collate function based on the batching strategy - # collate_pina_dataloader is used when automatic batching is disabled - # collate_torch_dataloader is used when automatic batching is enabled - self.callable_function = ( - self._collate_torch_dataloader - if automatic_batching - else (self._collate_pina_dataloader) - ) - self.dataset = dataset - - # Set the function which performs the actual collation - if isinstance(self.dataset, PinaTensorDataset): - # If the dataset is a PinaTensorDataset, use this collate function - self._collate = self._collate_tensor_dataset - else: - # If the dataset is a PinaDataset, use this collate function - self._collate = self._collate_graph_dataset - - def _collate_pina_dataloader(self, batch): - """ - Function used to create a batch when automatic batching is disabled. - - :param list[int] batch: List of integers representing the indices of - the data points to be fetched. - :return: Dictionary containing the data points fetched from the dataset. - :rtype: dict - """ - # Call the fetch_from_idx_list method of the dataset - return self.dataset.fetch_from_idx_list(batch) - - def _collate_torch_dataloader(self, batch): - """ - Function used to collate the batch - - :param list[dict] batch: List of retrieved data. - :return: Dictionary containing the data points fetched from the dataset, - collated. - :rtype: dict - """ - - batch_dict = {} - if isinstance(batch, dict): - return batch - conditions_names = batch[0].keys() - # Condition names - for condition_name in conditions_names: - single_cond_dict = {} - condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict - return batch_dict - - @staticmethod - def _collate_tensor_dataset(data_list): - """ - Function used to collate the data when the dataset is a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param data_list: Elements to be collated. - :type data_list: list[torch.Tensor] | list[LabelTensor] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - """ - - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - """ - Function used to collate data when the dataset is a - :class:`~pina.data.dataset.PinaGraphDataset`. - - :param data_list: Elememts to be collated. - :type data_list: list[Data] | list[Graph] - :return: Batch of data. - :rtype: dict + def __init__(self, condition, indices, automatic_batching): + super().__init__() + self.condition = condition + self.indices = indices + self.automatic_batching = automatic_batching - :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. - """ - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_batch(data_list) - raise RuntimeError( - "Data must be Tensors or LabelTensor or pyG " - "torch_geometric.data.Data" - ) + def __len__(self): + return len(self.indices) - def __call__(self, batch): + def __getitem__(self, idx): """ - Perform the collation of data fetched from the dataset. The behavoior - of the function is set based on the batching strategy during class - initialization. + Fetch the data from the dataset based on the list of indices. - :param batch: List of retrieved data or sampled indices. - :type batch: list[int] | list[dict] - :return: Dictionary containing colleted data fetched from the dataset. + :param int idx: The index of the data to be fetched. + :return: The data corresponding to the given index. :rtype: dict """ - - return self.callable_function(batch) - - -class PinaSampler: - """ - This class is used to create the sampler instance based on the shuffle - parameter and the environment in which the code is running. - """ - - def __new__(cls, dataset): - """ - Instantiate and initialize the sampler. - - :param PinaDataset dataset: The dataset from which to sample. - :return: The sampler instance. - :rtype: :class:`torch.utils.data.Sampler` - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - sampler = DistributedSampler(dataset) - else: - sampler = SequentialSampler(dataset) - return sampler + idx = self.indices[idx] + if not self.automatic_batching: + return idx + return self.condition[idx] + + def get_all_data(self): + data = self.condition[self.indices] + if "data" in data and isinstance(data["data"], list): + batch_fn = ( + LabelBatch.from_data_list + if isinstance(data["data"][0], Graph) + else Batch.from_data_list + ) + data["data"] = batch_fn(data["data"]) + data = { + "input": data["data"], + "target": data["data"].y, + } + return data class PinaDataModule(LightningDataModule): @@ -250,7 +75,7 @@ def __init__( val_size=0.1, batch_size=None, shuffle=True, - repeat=False, + batching_mode="separate_conditions", automatic_batching=None, num_workers=0, pin_memory=False, @@ -271,11 +96,9 @@ def __init__( Default is ``None``. :param bool shuffle: Whether to shuffle the dataset before splitting. Default ``True``. - :param bool repeat: If ``True``, in case of batch size larger than the - number of elements in a specific condition, the elements are - repeated until the batch size is reached. If ``False``, the number - of elements in the batch is the minimum between the batch size and - the number of elements in the condition. Default is ``False``. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. Default is ``"separate_conditions"``. :param automatic_batching: If ``True``, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. This is useful @@ -302,10 +125,11 @@ def __init__( """ super().__init__() + self.problem = problem # Store fixed attributes self.batch_size = batch_size self.shuffle = shuffle - self.repeat = repeat + self.batching_mode = batching_mode self.automatic_batching = automatic_batching # If batch size is None, num_workers has no effect @@ -327,41 +151,87 @@ def __init__( self.pin_memory = False else: self.pin_memory = pin_memory - - # Collect data - problem.collect_data() - - # Check if the splits are correct + self.problem.move_discretisation_into_conditions() self._check_slit_sizes(train_size, test_size, val_size) - # Split input data into subsets - splits_dict = {} if train_size > 0: - splits_dict["train"] = train_size self.train_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.train_dataloader = super().train_dataloader if test_size > 0: - splits_dict["test"] = test_size self.test_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.test_dataloader = super().test_dataloader if val_size > 0: - splits_dict["val"] = val_size self.val_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.data_splits = self._create_splits( - problem.collected_data, splits_dict + self._create_condition_splits(problem, train_size, test_size, val_size) + self.creator = _Creator( + batching_mode=batching_mode, + batch_size=batch_size, + shuffle=shuffle, + automatic_batching=automatic_batching, + num_workers=num_workers, + pin_memory=pin_memory, + conditions=problem.conditions, ) - self.transfer_batch_to_device = self._transfer_batch_to_device + + @staticmethod + def _check_slit_sizes(train_size, test_size, val_size): + """ + Check if the splits are correct. The splits sizes must be positive and + the sum of the splits must be 1. + + :param float train_size: The size of the training split. + :param float test_size: The size of the testing split. + :param float val_size: The size of the validation split. + + :raises ValueError: If at least one of the splits is negative. + :raises ValueError: If the sum of the splits is different + from 1. + """ + + if train_size < 0 or test_size < 0 or val_size < 0: + raise ValueError("The splits must be positive") + if abs(train_size + test_size + val_size - 1) > 1e-6: + raise ValueError("The sum of the splits must be 1") + + def _create_condition_splits( + self, problem, train_size, test_size, val_size + ): + self.split_idxs = {} + for condition_name, condition in problem.conditions.items(): + len_condition = len(condition) + # Create the indices for shuffling and splitting + indices = ( + torch.randperm(len_condition).tolist() + if self.shuffle + else list(range(len_condition)) + ) + + # Determine split sizes + train_end = int(train_size * len_condition) + test_end = train_end + int(test_size * len_condition) + + # Split indices + train_indices = indices[:train_end] + test_indices = indices[train_end:test_end] + val_indices = indices[test_end:] + splits = {} + splits["train"], splits["test"], splits["val"] = ( + train_indices, + test_indices, + val_indices, + ) + self.split_idxs[condition_name] = splits def setup(self, stage=None): """ @@ -374,209 +244,60 @@ def setup(self, stage=None): :raises ValueError: If the stage is neither "fit" nor "test". """ if stage == "fit" or stage is None: - self.train_dataset = PinaDatasetFactory( - self.data_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( - "train" - ), - automatic_batching=self.automatic_batching, - ) - if "val" in self.data_splits.keys(): - self.val_dataset = PinaDatasetFactory( - self.data_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( - "val" - ), + print("Sono qui") + self.train_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["train"], automatic_batching=self.automatic_batching, ) - elif stage == "test": - self.test_dataset = PinaDatasetFactory( - self.data_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), - automatic_batching=self.automatic_batching, - ) - else: - raise ValueError("stage must be either 'fit' or 'test'.") - - @staticmethod - def _split_condition(single_condition_dict, splits_dict): - """ - Split the condition into different stages. - - :param dict single_condition_dict: The condition to be split. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: A dictionary containing the split condition. - :rtype: dict - """ - - len_condition = len(single_condition_dict["input"]) - - lengths = [ - int(len_condition * length) for length in splits_dict.values() - ] - - remainder = len_condition - sum(lengths) - for i in range(remainder): - lengths[i % len(lengths)] += 1 - - splits_dict = { - k: max(1, v) for k, v in zip(splits_dict.keys(), lengths) - } - to_return_dict = {} - offset = 0 - - for stage, stage_len in splits_dict.items(): - to_return_dict[stage] = { - k: v[offset : offset + stage_len] - for k, v in single_condition_dict.items() - if k != "equation" - # Equations are NEVER dataloaded + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["train"]) > 0 } - if offset + stage_len >= len_condition: - offset = len_condition - 1 - continue - offset += stage_len - return to_return_dict - - def _create_splits(self, collector, splits_dict): - """ - Create the dataset objects putting data in the correct splits. - - :param Collector collector: The collector object containing the data. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: The dictionary containing the dataset objects. - :rtype: dict - """ - - # ----------- Auxiliary function ------------ - def _apply_shuffle(condition_dict, len_data): - idx = torch.randperm(len_data) - for k, v in condition_dict.items(): - if k == "equation": - continue - if isinstance(v, list): - condition_dict[k] = [v[i] for i in idx] - elif isinstance(v, LabelTensor): - condition_dict[k] = LabelTensor(v.tensor[idx], v.labels) - elif isinstance(v, torch.Tensor): - condition_dict[k] = v[idx] - else: - raise ValueError(f"Data type {type(v)} not supported") - - # ----------- End auxiliary function ------------ - - split_names = list(splits_dict.keys()) - dataset_dict = {name: {} for name in split_names} - for ( - condition_name, - condition_dict, - ) in collector.items(): - len_data = len(condition_dict["input"]) - if self.shuffle: - _apply_shuffle(condition_dict, len_data) - for key, data in self._split_condition( - condition_dict, splits_dict - ).items(): - dataset_dict[key].update({condition_name: data}) - return dataset_dict - - def _create_dataloader(self, split, dataset): - """ " - Create the dataloader for the given split. - - :param str split: The split on which to create the dataloader. - :param str dataset: The dataset to be used for the dataloader. - :return: The dataloader for the given split. - :rtype: torch.utils.data.DataLoader - """ - # Suppress the warning about num_workers. - # In many cases, especially for PINNs, - # serial data loading can outperform parallel data loading. - warnings.filterwarnings( - "ignore", - message=( - "The '(train|val|test)_dataloader' does not have many workers " - "which may be a bottleneck." - ), - module="lightning.pytorch.trainer.connectors.data_connector", - ) - # Use custom batching (good if batch size is large) - if self.batch_size is not None: - sampler = PinaSampler(dataset) - if self.automatic_batching: - collate = Collator( - self.find_max_conditions_lengths(split), - self.automatic_batching, - dataset=dataset, + print(self.train_datasets) + self.val_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["val"], + automatic_batching=self.automatic_batching, ) - else: - collate = Collator( - None, self.automatic_batching, dataset=dataset + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["val"]) > 0 + } + return + if stage == "test" or stage is None: + self.test_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["test"], + automatic_batching=self.automatic_batching, ) - return DataLoader( - dataset, - self.batch_size, - collate_fn=collate, - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["test"]) > 0 + } + else: + raise ValueError( + f"Invalid stage {stage}. Stage must be either 'fit' or 'test'." ) - dataloader = DummyDataloader(dataset) - dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0 - ) - self.transfer_batch_to_device = self._transfer_batch_to_device_dummy - return dataloader - - def find_max_conditions_lengths(self, split): - """ - Define the maximum length for each conditions. - - :param dict split: The split of the dataset. - :return: The maximum length per condition. - :rtype: dict - """ - - max_conditions_lengths = {} - for k, v in self.data_splits[split].items(): - if self.batch_size is None: - max_conditions_lengths[k] = len(v["input"]) - elif self.repeat: - max_conditions_lengths[k] = self.batch_size - else: - max_conditions_lengths[k] = min( - len(v["input"]), self.batch_size - ) - return max_conditions_lengths - - def val_dataloader(self): - """ - Create the validation dataloader. - - :return: The validation dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("val", self.val_dataset) def train_dataloader(self): - """ - Create the training dataloader + print(self.train_datasets) + return _Aggregator( + self.creator(self.train_datasets), + batching_mode="separate_conditions", + ) - :return: The training dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("train", self.train_dataset) + def val_dataloader(self): + print(self.val_datasets) + return _Aggregator( + self.creator(self.val_datasets), batching_mode="separate_conditions" + ) def test_dataloader(self): - """ - Create the testing dataloader - - :return: The testing dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("test", self.test_dataset) + return _Aggregator( + self.creator(self.test_datasets), + batching_mode="separate_conditions", + ) @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): @@ -591,10 +312,9 @@ def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): :return: The batch transferred to the device. :rtype: list[tuple] """ - return batch - def _transfer_batch_to_device(self, batch, device, dataloader_idx): + def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. @@ -606,53 +326,7 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx): :return: The batch transferred to the device. :rtype: list[tuple] """ - - batch = [ - ( - k, - super(LightningDataModule, self).transfer_batch_to_device( - v, device, dataloader_idx - ), - ) - for k, v in batch.items() - ] - - return batch - - @staticmethod - def _check_slit_sizes(train_size, test_size, val_size): - """ - Check if the splits are correct. The splits sizes must be positive and - the sum of the splits must be 1. - - :param float train_size: The size of the training split. - :param float test_size: The size of the testing split. - :param float val_size: The size of the validation split. - - :raises ValueError: If at least one of the splits is negative. - :raises ValueError: If the sum of the splits is different - from 1. - """ - - if train_size < 0 or test_size < 0 or val_size < 0: - raise ValueError("The splits must be positive") - if abs(train_size + test_size + val_size - 1) > 1e-6: - raise ValueError("The sum of the splits must be 1") - - @property - def input(self): - """ - Return all the input points coming from all the datasets. - - :return: The input points for training. - :rtype: dict - """ - - to_return = {} - if hasattr(self, "train_dataset") and self.train_dataset is not None: - to_return["train"] = self.train_dataset.input - if hasattr(self, "val_dataset") and self.val_dataset is not None: - to_return["val"] = self.val_dataset.input - if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return["test"] = self.test_dataset.input + to_return = [] + for condition_name, condition in batch.items(): + to_return.append((condition_name, condition.to(device))) return to_return diff --git a/pina/_src/data/dummy_dataloader.py b/pina/_src/data/dummy_dataloader.py new file mode 100644 index 000000000..c236e9d30 --- /dev/null +++ b/pina/_src/data/dummy_dataloader.py @@ -0,0 +1,62 @@ +""" +Module containing the ``DummyDataloader`` class +""" + +import torch + + +class DummyDataloader: + """ + A dummy dataloader that returns the entire dataset in a single batch. This + is used when the batch size is ``None``. It supports both distributed and + non-distributed environments. In a distributed environment, it divides the + dataset across processes using the rank and world size, fetching only the + portion of data corresponding to the current process. In a non-distributed + environment, it fetches the entire dataset. + """ + + def __init__(self, dataset): + """ + Prepare a dataloader object that returns the entire dataset in a single + batch. Depending on the number of GPUs, the dataset is managed + as follows: + + - **Distributed Environment** (multiple GPUs): Divides dataset across + processes using the rank and world size. Fetches only portion of + data corresponding to the current process. + - **Non-Distributed Environment** (single GPU): Fetches the entire + dataset. + + :param PinaDataset dataset: The dataset object to be processed. + + .. note:: + This dataloader is used when the batch size is ``None``. + """ + + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + if len(dataset) < world_size: + raise RuntimeError( + "Dimension of the dataset smaller than world size." + " Increase the size of the partition or use a single GPU" + ) + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + self.dataset = dataset.fetch_from_idx_list(idx).to_batch() + else: + self.dataset = dataset.get_all_data().to_batch() + + def __iter__(self): + return self + + def __len__(self): + return 1 + + def __next__(self): + return self.dataset diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index cfaeb5bec..b781c8067 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -11,6 +11,7 @@ ) from pina._src.core.label_tensor import LabelTensor from pina._src.core.utils import merge_tensors, custom_warning_format +from pina._src.condition.condition import Condition class AbstractProblem(metaclass=ABCMeta): @@ -318,34 +319,49 @@ def add_points(self, new_points_dict): [self.discretised_domains[k], v] ) - def collect_data(self): + def move_discretisation_into_conditions(self): """ - Aggregate data from the problem's conditions into a single dictionary. + Move the discretised domains into their corresponding conditions. """ - data = {} - # Iterate over the conditions and collect data - for condition_name in self.conditions: - condition = self.conditions[condition_name] - # Check if the condition has an domain attribute - if hasattr(condition, "domain"): - # Only store the discretisation points if the domain is - # in the dictionary - if condition.domain in self.discretised_domains: - samples = self.discretised_domains[condition.domain][ - self.input_variables - ] - data[condition_name] = { - "input": samples, - "equation": condition.equation, - } - else: - # If the condition does not have a domain attribute, store - # the input and target points - keys = condition.__slots__ - values = [ - getattr(condition, name) - for name in keys - if getattr(condition, name) is not None - ] - data[condition_name] = dict(zip(keys, values)) - self._collected_data = data + + for name, cond in self.conditions.items(): + if hasattr(cond, "domain"): + domain = cond.domain + self.conditions[name] = Condition( + input=self.discretised_domains[cond.domain], + equation=cond.equation, + ) + self.conditions[name].domain = domain + self.conditions[name].problem = self + + # def collect_data(self): + # """ + # Aggregate data from the problem's conditions into a single dictionary. + # """ + # data = {} + # # Iterate over the conditions and collect data + # for condition_name in self.conditions: + # condition = self.conditions[condition_name] + # # Check if the condition has an domain attribute + # if hasattr(condition, "domain"): + # # Only store the discretisation points if the domain is + # # in the dictionary + # if condition.domain in self.discretised_domains: + # samples = self.discretised_domains[condition.domain][ + # self.input_variables + # ] + # data[condition_name] = { + # "input": samples, + # "equation": condition.equation, + # } + # else: + # # If the condition does not have a domain attribute, store + # # the input and target points + # keys = condition.__slots__ + # values = [ + # getattr(condition, name) + # for name in keys + # if getattr(condition, name) is not None + # ] + # data[condition_name] = dict(zip(keys, values)) + # self._collected_data = data From 975f0ef0d4a3a48d12214217eef171be8fa1b50b Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 13 Feb 2026 09:45:25 +0100 Subject: [PATCH 2/8] remove useless code in abstract_problem.py --- pina/_src/problem/abstract_problem.py | 88 ++++++--------------------- pina/data/__init__.py | 18 ------ 2 files changed, 19 insertions(+), 87 deletions(-) diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index b781c8067..cc2b9e042 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -43,43 +43,6 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - self._collected_data = {} - - @property - def collected_data(self): - """ - Return the collected data from the problem's conditions. If some domains - are not sampled, they will not be returned by collected data. - - :return: The collected data. Keys are condition names, and values are - dictionaries containing the input points and the corresponding - equations or target points. - :rtype: dict - """ - # collect data so far - self.collect_data() - # raise warning if some sample data are missing - if not self.are_all_domains_discretised: - warnings.formatwarning = custom_warning_format - warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { - "sampled" if key in self.discretised_domains - else - "not sampled"}""" - for key in self.domains - ] - ) - warnings.warn( - "Some of the domains are still not sampled. Consider calling " - "problem.discretise_domain function for all domains before " - "accessing the collected data:\n" - f"{warning_message}", - RuntimeWarning, - ) - return self._collected_data - # back compatibility 0.1 @property def input_pts(self): @@ -323,6 +286,25 @@ def move_discretisation_into_conditions(self): """ Move the discretised domains into their corresponding conditions. """ + if not self.are_all_domains_discretised: + warnings.formatwarning = custom_warning_format + warnings.filterwarnings("always", category=RuntimeWarning) + warning_message = "\n".join( + [ + f"""{" " * 13} ---> Domain {key} { + "sampled" if key in self.discretised_domains + else + "not sampled"}""" + for key in self.domains + ] + ) + warnings.warn( + "Some of the domains are still not sampled. Consider calling " + "problem.discretise_domain function for all domains before " + "accessing the collected data:\n" + f"{warning_message}", + RuntimeWarning, + ) for name, cond in self.conditions.items(): if hasattr(cond, "domain"): @@ -333,35 +315,3 @@ def move_discretisation_into_conditions(self): ) self.conditions[name].domain = domain self.conditions[name].problem = self - - # def collect_data(self): - # """ - # Aggregate data from the problem's conditions into a single dictionary. - # """ - # data = {} - # # Iterate over the conditions and collect data - # for condition_name in self.conditions: - # condition = self.conditions[condition_name] - # # Check if the condition has an domain attribute - # if hasattr(condition, "domain"): - # # Only store the discretisation points if the domain is - # # in the dictionary - # if condition.domain in self.discretised_domains: - # samples = self.discretised_domains[condition.domain][ - # self.input_variables - # ] - # data[condition_name] = { - # "input": samples, - # "equation": condition.equation, - # } - # else: - # # If the condition does not have a domain attribute, store - # # the input and target points - # keys = condition.__slots__ - # values = [ - # getattr(condition, name) - # for name in keys - # if getattr(condition, name) is not None - # ] - # data[condition_name] = dict(zip(keys, values)) - # self._collected_data = data diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 2ecebecdd..f274d5bd9 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -7,26 +7,8 @@ from pina._src.data.data_module import ( PinaDataModule, - PinaSampler, - DummyDataloader, - Collator, - PinaSampler, -) - -from pina._src.data.dataset import ( - PinaDataset, - PinaTensorDataset, - PinaGraphDataset, - PinaDatasetFactory, ) __all__ = [ "PinaDataModule", - "PinaDataset", - "PinaSampler", - "DummyDataloader", - "Collator", - "PinaTensorDataset", - "PinaGraphDataset", - "PinaDatasetFactory", ] From 6eed64adf38acac49b627f71889cdb0161db82c4 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 17 Feb 2026 10:19:54 +0100 Subject: [PATCH 3/8] fix bugs --- pina/_src/data/aggregator.py | 11 +++++++---- pina/_src/data/data_module.py | 20 ++++++++------------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py index c788132c2..605af5d46 100644 --- a/pina/_src/data/aggregator.py +++ b/pina/_src/data/aggregator.py @@ -32,6 +32,8 @@ def __len__(self): :return: The length of the aggregated dataloader. :rtype: int """ + if self.batching_mode == "separate_conditions": + return sum(len(dl) for dl in self.dataloaders.values()) return max(len(dl) for dl in self.dataloaders.values()) def __iter__(self): @@ -42,10 +44,11 @@ def __iter__(self): :rtype: iterator """ if self.batching_mode == "separate_conditions": - for name, dl in self.dataloaders.items(): - for batch in dl: - yield {name: batch} - return + # TODO: implement separate_conditions batching mode + raise NotImplementedError( + "Batching mode 'separate_conditions' is not implemented yet." + ) + iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} for _ in range(len(self)): batch = {} diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index b39596eaf..c364cde54 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -7,12 +7,9 @@ import warnings from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Data -from torch.utils.data import DataLoader, SequentialSampler -from torch.utils.data.distributed import DistributedSampler -from pina._src.core.label_tensor import LabelTensor -from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset +from torch_geometric.data import Batch from pina._src.data.creator import _Creator +from pina._src.core.graph import LabelBatch, Graph from pina._src.data.aggregator import _Aggregator @@ -131,6 +128,7 @@ def __init__( self.shuffle = shuffle self.batching_mode = batching_mode self.automatic_batching = automatic_batching + self.batching_mode = batching_mode # If batch size is None, num_workers has no effect if batch_size is None and num_workers != 0: @@ -244,7 +242,6 @@ def setup(self, stage=None): :raises ValueError: If the stage is neither "fit" nor "test". """ if stage == "fit" or stage is None: - print("Sono qui") self.train_datasets = { name: _ConditionSubset( condition, @@ -254,7 +251,7 @@ def setup(self, stage=None): for name, condition in self.problem.conditions.items() if len(self.split_idxs[name]["train"]) > 0 } - print(self.train_datasets) + self.val_datasets = { name: _ConditionSubset( condition, @@ -265,6 +262,7 @@ def setup(self, stage=None): if len(self.split_idxs[name]["val"]) > 0 } return + if stage == "test" or stage is None: self.test_datasets = { name: _ConditionSubset( @@ -281,22 +279,20 @@ def setup(self, stage=None): ) def train_dataloader(self): - print(self.train_datasets) return _Aggregator( self.creator(self.train_datasets), - batching_mode="separate_conditions", + batching_mode=self.batching_mode, ) def val_dataloader(self): - print(self.val_datasets) return _Aggregator( - self.creator(self.val_datasets), batching_mode="separate_conditions" + self.creator(self.val_datasets), batching_mode=self.batching_mode ) def test_dataloader(self): return _Aggregator( self.creator(self.test_datasets), - batching_mode="separate_conditions", + batching_mode=self.batching_mode, ) @staticmethod From fc0a7e9fc6b79d0828624f8f01c36d08fe5a01fe Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 17 Feb 2026 10:20:46 +0100 Subject: [PATCH 4/8] fix batching bug with LabelTensor --- pina/_src/condition/data_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/_src/condition/data_manager.py b/pina/_src/condition/data_manager.py index b390cb580..3d0e5a1d5 100644 --- a/pina/_src/condition/data_manager.py +++ b/pina/_src/condition/data_manager.py @@ -119,7 +119,7 @@ def create_batch(items): if isinstance(sample, LabelTensor) else torch.stack ) - batch_data[k] = batch_fn(vals, dim=0) + batch_data[k] = batch_fn(vals) else: batch_data[k] = sample return batch_data From 47ec0908ed3977c576f5b6a8ce0ac6e3632a21e0 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 17 Feb 2026 10:21:59 +0100 Subject: [PATCH 5/8] add switch_dataloader_fn --- pina/_src/condition/condition_base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py index 44a8af2b7..8c6f0d269 100644 --- a/pina/_src/condition/condition_base.py +++ b/pina/_src/condition/condition_base.py @@ -34,6 +34,7 @@ def __init__(self, **kwargs): """ super().__init__() self.data = self.store_data(**kwargs) + self.has_custom_dataloader_fn = False @property def problem(self): @@ -131,3 +132,17 @@ def create_dataloader( batch_size=batch_size, **kwargs, ) + + def switch_dataloader_fn(self, create_dataloader_fn): + """ + Decorator to switch the dataloader function for a condition. + + :param create_dataloader_fn: The new dataloader function to use. + :type create_dataloader_fn: function + :return: The decorated function with the new dataloader function. + :rtype: function + """ + # Replace the create_dataloader method of the ConditionBase class with + # the new function + self.has_custom_dataloader_fn = True + self.create_dataloader = create_dataloader_fn From f36efe28d3c696a25e45aadc251f5b462816c698 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 2 Mar 2026 16:01:27 +0100 Subject: [PATCH 6/8] fix bugs --- pina/_src/core/trainer.py | 18 ++++++++++++++++++ pina/_src/data/data_module.py | 11 +++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py index 377b42fac..939da67f6 100644 --- a/pina/_src/core/trainer.py +++ b/pina/_src/core/trainer.py @@ -131,12 +131,30 @@ def __init__( automatic_batching if automatic_batching is not None else False ) + if batch_size is None and batching_mode != "common_batch_size": + warnings.warn( + "Batching mode is set to " + f"{batching_mode} but batch_size is None. " + "Batching mode will be set to common_batch_size.", + UserWarning, + ) + batching_mode = "common_batch_size" + + if batch_size == 1 and batching_mode == "proportional": + warnings.warn( + "Batching mode is set to proportional but batch_size is 1. " + "Batching mode will be set to common_batch_size.", + UserWarning, + ) + batching_mode = "common_batch_size" + # set attributes self.compile = compile self.solver = solver self.batch_size = batch_size self._move_to_device() self.data_module = None + self._create_datamodule( train_size=train_size, test_size=test_size, diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index c364cde54..ef9e6714d 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -72,7 +72,7 @@ def __init__( val_size=0.1, batch_size=None, shuffle=True, - batching_mode="separate_conditions", + batching_mode="common_batch_size", automatic_batching=None, num_workers=0, pin_memory=False, @@ -95,7 +95,7 @@ def __init__( Default ``True``. :param str batching_mode: The batching mode to use. Options are ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. Default is ``"separate_conditions"``. + ``"separate_conditions"``. Default is ``"common_batch_size"``. :param automatic_batching: If ``True``, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. This is useful @@ -241,7 +241,7 @@ def setup(self, stage=None): :raises ValueError: If the stage is neither "fit" nor "test". """ - if stage == "fit" or stage is None: + if stage in ("fit", None): self.train_datasets = { name: _ConditionSubset( condition, @@ -261,9 +261,8 @@ def setup(self, stage=None): for name, condition in self.problem.conditions.items() if len(self.split_idxs[name]["val"]) > 0 } - return - if stage == "test" or stage is None: + if stage in ("test", None): self.test_datasets = { name: _ConditionSubset( condition, @@ -273,7 +272,7 @@ def setup(self, stage=None): for name, condition in self.problem.conditions.items() if len(self.split_idxs[name]["test"]) > 0 } - else: + if stage not in ("fit", "test", None): raise ValueError( f"Invalid stage {stage}. Stage must be either 'fit' or 'test'." ) From fea450b60fbbdcc84fc1e2d6582a4b4d25438057 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 3 Mar 2026 09:59:58 +0100 Subject: [PATCH 7/8] bug fix and add tests --- pina/_src/condition/condition_base.py | 2 +- pina/_src/core/trainer.py | 6 +- tests/test_data/test_data_module.py | 331 ------------------------- tests/test_data/test_graph_dataset.py | 138 ----------- tests/test_data/test_tensor_dataset.py | 86 ------- tests/test_datamodule.py | 318 ++++++++++++++++++++++++ 6 files changed, 324 insertions(+), 557 deletions(-) delete mode 100644 tests/test_data/test_data_module.py delete mode 100644 tests/test_data/test_graph_dataset.py delete mode 100644 tests/test_data/test_tensor_dataset.py create mode 100644 tests/test_datamodule.py diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py index 8c6f0d269..0d1a8cb15 100644 --- a/pina/_src/condition/condition_base.py +++ b/pina/_src/condition/condition_base.py @@ -142,7 +142,7 @@ def switch_dataloader_fn(self, create_dataloader_fn): :return: The decorated function with the new dataloader function. :rtype: function """ - # Replace the create_dataloader method of the ConditionBase class with + # Replace the create_dataloader method of the ConditionBase class with # the new function self.has_custom_dataloader_fn = True self.create_dataloader = create_dataloader_fn diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py index 939da67f6..d18350d14 100644 --- a/pina/_src/core/trainer.py +++ b/pina/_src/core/trainer.py @@ -140,7 +140,11 @@ def __init__( ) batching_mode = "common_batch_size" - if batch_size == 1 and batching_mode == "proportional": + if ( + batch_size is not None + and batch_size <= len(solver.problem.conditions) + and batching_mode == "proportional" + ): warnings.warn( "Batching mode is set to proportional but batch_size is 1. " "Batching mode will be set to common_batch_size.", diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py deleted file mode 100644 index 9fd2d36ee..000000000 --- a/tests/test_data/test_data_module.py +++ /dev/null @@ -1,331 +0,0 @@ -import torch -import pytest -from pina.data import PinaDataModule -from pina.data import PinaTensorDataset, PinaGraphDataset -from pina.problem.zoo import SupervisedProblem -from pina.graph import RadiusGraph -from pina.data import DummyDataloader -from pina import Trainer -from pina.solver import SupervisedSolver -from torch_geometric.data import Batch -from torch.utils.data import DataLoader - -input_tensor = torch.rand((100, 10)) -output_tensor = torch.rand((100, 2)) - -x = torch.rand((100, 50, 10)) -pos = torch.rand((100, 50, 2)) -input_graph = [ - RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) -] -output_graph = torch.rand((100, 50, 10)) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -def test_constructor(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) - PinaDataModule(problem) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize( - "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.7, 0.3, 0)] -) -def test_setup_train(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule( - problem, train_size=train_size, val_size=val_size, test_size=test_size - ) - dm.setup() - assert hasattr(dm, "train_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.train_dataset, PinaTensorDataset) - else: - assert isinstance(dm.train_dataset, PinaGraphDataset) - # assert len(dm.train_dataset) == int(len(input_) * train_size) - if test_size > 0: - assert hasattr(dm, "test_dataset") - assert dm.test_dataset is None - else: - assert not hasattr(dm, "test_dataset") - assert hasattr(dm, "val_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.val_dataset, PinaTensorDataset) - else: - assert isinstance(dm.val_dataset, PinaGraphDataset) - # assert len(dm.val_dataset) == int(len(input_) * val_size) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize( - "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.0, 0.0, 1.0)] -) -def test_setup_test(input_, output_, train_size, val_size, test_size): - problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule( - problem, train_size=train_size, val_size=val_size, test_size=test_size - ) - dm.setup(stage="test") - if train_size > 0: - assert hasattr(dm, "train_dataset") - assert dm.train_dataset is None - else: - assert not hasattr(dm, "train_dataset") - if val_size > 0: - assert hasattr(dm, "val_dataset") - assert dm.val_dataset is None - else: - assert not hasattr(dm, "val_dataset") - - assert hasattr(dm, "test_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.test_dataset, PinaTensorDataset) - else: - assert isinstance(dm.test_dataset, PinaGraphDataset) - # assert len(dm.test_dataset) == int(len(input_) * test_size) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -def test_dummy_dataloader(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize("automatic_batching", [True, False]) -def test_dataloader(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, - batch_size=10, - train_size=0.7, - val_size=0.3, - test_size=0.0, - automatic_batching=automatic_batching, - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 7 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 3 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - - -from pina import LabelTensor - -input_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) -output_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) - -x = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) -pos = LabelTensor(torch.rand((100, 50, 2)), ["x", "y"]) -input_graph = [ - RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x)) -] -output_graph = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) - - -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -@pytest.mark.parametrize("automatic_batching", [True, False]) -def test_dataloader_labels(input_, output_, automatic_batching): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, - batch_size=10, - train_size=0.7, - val_size=0.3, - test_size=0.0, - automatic_batching=automatic_batching, - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 7 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["input"].x, LabelTensor) - assert data["data"]["input"].x.labels == ["u", "v", "w"] - assert data["data"]["input"].pos.labels == ["x", "y"] - else: - assert isinstance(data["data"]["input"], LabelTensor) - assert data["data"]["input"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["target"], LabelTensor) - assert data["data"]["target"].labels == ["u", "v", "w"] - - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 3 - data = next(iter(dataloader)) - assert isinstance(data, dict) - if isinstance(input_, list): - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["input"].x, LabelTensor) - assert data["data"]["input"].x.labels == ["u", "v", "w"] - assert data["data"]["input"].pos.labels == ["x", "y"] - else: - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["input"], LabelTensor) - assert data["data"]["input"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["target"], torch.Tensor) - assert data["data"]["target"].labels == ["u", "v", "w"] - - -def test_get_all_data(): - input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) - target = input - - problem = SupervisedProblem(input, target) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700 - assert torch.isclose( - datamodule.train_dataset.get_all_data()["data"]["input"], input[:700] - ).all() - assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100 - assert torch.isclose( - datamodule.val_dataset.get_all_data()["data"]["input"], input[900:] - ).all() - assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200 - assert torch.isclose( - datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900] - ).all() - - -def test_input_propery_tensor(): - input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) - target = input - - problem = SupervisedProblem(input, target) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - input_ = datamodule.input - assert isinstance(input_, dict) - assert isinstance(input_["train"], dict) - assert isinstance(input_["val"], dict) - assert isinstance(input_["test"], dict) - assert torch.isclose(input_["train"]["data"], input[:700]).all() - assert torch.isclose(input_["val"]["data"], input[900:]).all() - assert torch.isclose(input_["test"]["data"], input[700:900]).all() - - -def test_input_propery_graph(): - problem = SupervisedProblem(input_graph, output_graph) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - input_ = datamodule.input - assert isinstance(input_, dict) - assert isinstance(input_["train"], dict) - assert isinstance(input_["val"], dict) - assert isinstance(input_["test"], dict) - assert isinstance(input_["train"]["data"], list) - assert isinstance(input_["val"]["data"], list) - assert isinstance(input_["test"]["data"], list) - assert len(input_["train"]["data"]) == 70 - assert len(input_["val"]["data"]) == 10 - assert len(input_["test"]["data"]) == 20 diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py deleted file mode 100644 index 3a63f7ec6..000000000 --- a/tests/test_data/test_graph_dataset.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch -import pytest -from pina.data import PinaDatasetFactory, PinaGraphDataset -from pina.graph import KNNGraph -from torch_geometric.data import Data - -x = torch.rand((100, 20, 10)) -pos = torch.rand((100, 20, 2)) -input_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x, pos) -] -output_ = torch.rand((100, 20, 10)) - -x_2 = torch.rand((50, 20, 10)) -pos_2 = torch.rand((50, 20, 2)) -input_2_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x_2, pos_2) -] -output_2_ = torch.rand((50, 20, 10)) - - -# Problem with a single condition -conditions_dict_single = { - "data": { - "input": input_, - "target": output_, - } -} -max_conditions_lengths_single = {"data": 100} - -# Problem with multiple conditions -conditions_dict_multi = { - "data_1": { - "input": input_, - "target": output_, - }, - "data_2": { - "input": input_2_, - "target": output_2_, - }, -} - -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_constructor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaGraphDataset) - assert len(dataset) == 100 - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_getitem(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - data = dataset[50] - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 60)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()]) - - data = dataset.fetch_from_idx_list([i for i in range(20)]) - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 1200)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) - - -def test_input_single_condition(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data"], list) - assert all([isinstance(d, Data) for d in input_["data"]]) - - -def test_input_multi_condition(): - dataset = PinaDatasetFactory( - conditions_dict_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data_1"], list) - assert all([isinstance(d, Data) for d in input_["data_1"]]) - assert isinstance(input_["data_2"], list) - assert all([isinstance(d, Data) for d in input_["data_2"]]) diff --git a/tests/test_data/test_tensor_dataset.py b/tests/test_data/test_tensor_dataset.py deleted file mode 100644 index 9e348c942..000000000 --- a/tests/test_data/test_tensor_dataset.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import pytest -from pina.data import PinaDatasetFactory, PinaTensorDataset - -input_tensor = torch.rand((100, 10)) -output_tensor = torch.rand((100, 2)) - -input_tensor_2 = torch.rand((50, 10)) -output_tensor_2 = torch.rand((50, 2)) - -conditions_dict_single = { - "data": { - "input": input_tensor, - "target": output_tensor, - } -} - -conditions_dict_single_multi = { - "data_1": { - "input": input_tensor, - "target": output_tensor, - }, - "data_2": { - "input": input_tensor_2, - "target": output_tensor_2, - }, -} - -max_conditions_lengths_single = {"data": 100} - -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} - - -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi), - ], -) -def test_constructor_tensor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaTensorDataset) - - -def test_getitem_single(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=False, - ) - - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data"] - assert sorted(list(tensors["data"].keys())) == ["input", "target"] - assert isinstance(tensors["data"]["input"], torch.Tensor) - assert tensors["data"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data"]["target"], torch.Tensor) - assert tensors["data"]["target"].shape == torch.Size((70, 2)) - - -def test_getitem_multi(): - dataset = PinaDatasetFactory( - conditions_dict_single_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=False, - ) - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data_1", "data_2"] - assert sorted(list(tensors["data_1"].keys())) == ["input", "target"] - assert isinstance(tensors["data_1"]["input"], torch.Tensor) - assert tensors["data_1"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data_1"]["target"], torch.Tensor) - assert tensors["data_1"]["target"].shape == torch.Size((70, 2)) - - assert sorted(list(tensors["data_2"].keys())) == ["input", "target"] - assert isinstance(tensors["data_2"]["input"], torch.Tensor) - assert tensors["data_2"]["input"].shape == torch.Size((50, 10)) - assert isinstance(tensors["data_2"]["target"], torch.Tensor) - assert tensors["data_2"]["target"].shape == torch.Size((50, 2)) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py new file mode 100644 index 000000000..8419a68f2 --- /dev/null +++ b/tests/test_datamodule.py @@ -0,0 +1,318 @@ +import torch +import pytest +from pina.data import PinaDataModule + +# from pina.data import PinaTensorDataset, PinaGraphDataset +from pina.problem.zoo import SupervisedProblem +from pina.graph import RadiusGraph + +# from pina.data import DummyDataloader +from pina._src.data.data_module import _ConditionSubset +from pina import Trainer +from pina.solver import SupervisedSolver +from torch_geometric.data import Batch +from torch.utils.data import DataLoader +from pina.problem.zoo import Poisson2DSquareProblem +from pina._src.data.aggregator import _Aggregator +from pina.solver import PINN + + +def _create_tensor_data(): + input_tensor = torch.rand((100, 10)) + output_tensor = torch.rand((100, 2)) + return input_tensor, output_tensor + + +def _create_graph_data(): + x = torch.rand((100, 50, 10)) + pos = torch.rand((100, 50, 2)) + input_graph = [ + RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) + ] + output_graph = torch.rand((100, 50, 2)) + return input_graph, output_graph + + +def test_init_tensor(): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + dm = PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"data"} + assert isinstance(dm.split_idxs["data"], dict) + assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["data"]["train"], list) + assert isinstance(dm.split_idxs["data"]["val"], list) + assert isinstance(dm.split_idxs["data"]["test"], list) + assert len(dm.split_idxs["data"]["train"]) == 70 + assert len(dm.split_idxs["data"]["val"]) == 10 + assert len(dm.split_idxs["data"]["test"]) == 20 + + +def test_init_graph(): + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + dm = PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"data"} + assert isinstance(dm.split_idxs["data"], dict) + assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["data"]["train"], list) + assert isinstance(dm.split_idxs["data"]["val"], list) + assert isinstance(dm.split_idxs["data"]["test"], list) + assert len(dm.split_idxs["data"]["train"]) == 70 + assert len(dm.split_idxs["data"]["val"]) == 10 + assert len(dm.split_idxs["data"]["test"]) == 20 + + +def test_init_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + dm = PinaDataModule(problem) + assert dm.problem == problem + assert dm.trainer is None + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == {"D", "boundary"} + assert isinstance(dm.split_idxs["D"], dict) + assert set(dm.split_idxs["D"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["D"]["train"], list) + assert isinstance(dm.split_idxs["D"]["val"], list) + assert isinstance(dm.split_idxs["D"]["test"], list) + assert len(dm.split_idxs["D"]["train"]) == 70 + assert len(dm.split_idxs["D"]["val"]) == 10 + assert len(dm.split_idxs["D"]["test"]) == 20 + + assert isinstance(dm.split_idxs["boundary"], dict) + assert set(dm.split_idxs["boundary"].keys()) == {"train", "val", "test"} + assert isinstance(dm.split_idxs["boundary"]["train"], list) + assert isinstance(dm.split_idxs["boundary"]["val"], list) + assert isinstance(dm.split_idxs["boundary"]["test"], list) + assert len(dm.split_idxs["boundary"]["train"]) == 7 + assert len(dm.split_idxs["boundary"]["val"]) == 1 + assert len(dm.split_idxs["boundary"]["test"]) == 2 + + +def test_setup_tensor(): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"data"} + assert isinstance(dm.train_datasets["data"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"data"} + assert isinstance(dm.val_datasets["data"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"data"} + assert isinstance(dm.test_datasets["data"], _ConditionSubset) + + +def test_setup_graph(): + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"data"} + assert isinstance(dm.train_datasets["data"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"data"} + assert isinstance(dm.val_datasets["data"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"data"} + assert isinstance(dm.test_datasets["data"], _ConditionSubset) + + +def test_setup_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + dm = PinaDataModule(problem) + dm.setup() + assert hasattr(dm, "train_datasets") + assert isinstance(dm.train_datasets, dict) + assert set(dm.train_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.train_datasets["D"], _ConditionSubset) + assert isinstance(dm.train_datasets["boundary"], _ConditionSubset) + assert hasattr(dm, "val_datasets") + assert isinstance(dm.val_datasets, dict) + assert set(dm.val_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.val_datasets["D"], _ConditionSubset) + assert isinstance(dm.val_datasets["boundary"], _ConditionSubset) + assert hasattr(dm, "test_datasets") + assert isinstance(dm.test_datasets, dict) + assert set(dm.test_datasets.keys()) == {"D", "boundary"} + assert isinstance(dm.test_datasets["D"], _ConditionSubset) + assert isinstance(dm.test_datasets["boundary"], _ConditionSubset) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_tensor(batch_size): + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + trainer = Trainer( + solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), + batch_size=batch_size, + train_size=0.7, + val_size=0.2, + test_size=0.1, + ) + dm = trainer.data_module + dm.setup() + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 70 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 10 + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_graph(batch_size): + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + trainer = Trainer( + solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), + train_size=0.7, + val_size=0.2, + test_size=0.1, + batch_size=batch_size, + ) + dm = trainer.data_module + dm.setup() + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], Batch) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 70 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["data"]["input"], Batch) + assert isinstance(data["data"]["target"], torch.Tensor) + assert ( + len(data["data"]["input"]) == batch_size + if batch_size is not None + else 10 + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +def test_dataloader_poisson_cbs(batch_size): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + trainer = Trainer( + solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + batch_size=batch_size, + val_size=0.1, + test_size=0.2, + train_size=0.7, + batching_mode="common_batch_size", + ) + dm = trainer.data_module + dm.setup() + + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == batch_size if batch_size is not None else 70 + ) + assert ( + len(data["boundary"]["input"]) == min(batch_size, 7) + if batch_size is not None + else 7 + ) + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == min(batch_size, 10) + if batch_size is not None + else 10 + ) + assert ( + len(data["boundary"]["input"]) == min(batch_size, 1) + if batch_size is not None + else 1 + ) + + +@pytest.mark.parametrize("batch_size", [None, 5, 20]) +def test_dataloader_poisson_proportional(batch_size): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="grid") + trainer = Trainer( + solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + batch_size=batch_size, + val_size=0.1, + test_size=0.2, + train_size=0.7, + batching_mode="proportional", + ) + dm = trainer.data_module + dm.setup() + + dataloader = dm.train_dataloader() + assert isinstance(dataloader, _Aggregator) + data = next(iter(dataloader)) + assert isinstance(data, dict) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["D"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert isinstance(data["boundary"]["input"], torch.Tensor) + assert ( + len(data["D"]["input"]) == batch_size - 1 + if batch_size is not None + else 70 + ) + assert len(data["boundary"]["input"]) == 1 if batch_size is not None else 7 From 4c693d4965674ffb631827265b737b014495c64a Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 4 Mar 2026 09:29:15 +0100 Subject: [PATCH 8/8] fix common_batch_size iteration bug --- pina/_src/data/creator.py | 4 ++++ pina/_src/data/data_module.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py index b0e6d37c1..0e84aef72 100644 --- a/pina/_src/data/creator.py +++ b/pina/_src/data/creator.py @@ -166,7 +166,11 @@ def __call__(self, datasets): # Compute batch sizes per condition based on batching_mode batch_sizes = self._compute_batch_sizes(datasets) dataloaders = {} + if self.batching_mode == "common_batch_size": + max_len = max(len(dataset) for dataset in datasets.values()) for name, dataset in datasets.items(): + if self.batching_mode == "common_batch_size": + dataset.max_len = max_len dataloaders[name] = self.conditions[name].create_dataloader( dataset=dataset, batch_size=batch_sizes[name], diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index ef9e6714d..d0fb5989a 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -24,9 +24,11 @@ def __init__(self, condition, indices, automatic_batching): self.condition = condition self.indices = indices self.automatic_batching = automatic_batching + self.length = len(self.indices) + self.max_len = self.length def __len__(self): - return len(self.indices) + return self.max_len def __getitem__(self, idx): """ @@ -36,6 +38,8 @@ def __getitem__(self, idx): :return: The data corresponding to the given index. :rtype: dict """ + if idx >= self.length: + idx = idx % self.length idx = self.indices[idx] if not self.automatic_batching: return idx