From d469b552f7cbdfe11334681609425e528afc9916 Mon Sep 17 00:00:00 2001 From: The dataclass_array Authors Date: Tue, 13 Jan 2026 14:36:51 -0800 Subject: [PATCH] Fix _ArrayField.inner_shape for scalar values in PyTorch contexts. Corrected an edge case in `_ArrayField.inner_shape` where, when called on a scalar value (empty `full_shape`) within a PyTorch environment, it now returns an empty shape `()` to prevent shape validation errors, particularly when used inside `torch.vmap`. PiperOrigin-RevId: 855886831 --- dataclass_array/array_dataclass.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dataclass_array/array_dataclass.py b/dataclass_array/array_dataclass.py index de884b3..65c2cae 100644 --- a/dataclass_array/array_dataclass.py +++ b/dataclass_array/array_dataclass.py @@ -1122,11 +1122,9 @@ def inner_shape(self) -> Shape: """Returns the the static shape resolved for the current value.""" # torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally, # messing up shape inference. - if ( - enp.lazy.has_torch - and isinstance(self.value, int) - and self.value == 0 - ): + if not self.full_shape and self.inner_shape_non_static: + if enp.lazy.has_torch: + return () return () if not self.inner_shape_non_static: return ()