diff --git a/dataclass_array/array_dataclass.py b/dataclass_array/array_dataclass.py index de884b3..65c2cae 100644 --- a/dataclass_array/array_dataclass.py +++ b/dataclass_array/array_dataclass.py @@ -1122,11 +1122,9 @@ def inner_shape(self) -> Shape: """Returns the the static shape resolved for the current value.""" # torch.func.vmap calls `tree_unflatten([0] * num_leaves)` internally, # messing up shape inference. - if ( - enp.lazy.has_torch - and isinstance(self.value, int) - and self.value == 0 - ): + if not self.full_shape and self.inner_shape_non_static: + if enp.lazy.has_torch: + return () return () if not self.inner_shape_non_static: return ()