diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index f677c197dc..d2ab353d43 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -16,7 +16,16 @@ from zarr.core.buffer import Buffer, BufferPrototype -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +__all__ = [ + "ByteGetter", + "ByteSetter", + "Store", + "SupportsDeleteSync", + "SupportsGetSync", + "SupportsSetSync", + "SupportsSyncStore", + "set_or_delete", +] @dataclass(frozen=True, slots=True) @@ -700,6 +709,31 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +@runtime_checkable +class SupportsGetSync(Protocol): + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: ... + + +@runtime_checkable +class SupportsSetSync(Protocol): + def set_sync(self, key: str, value: Buffer) -> None: ... + + +@runtime_checkable +class SupportsDeleteSync(Protocol): + def delete_sync(self, key: str) -> None: ... + + +@runtime_checkable +class SupportsSyncStore(SupportsGetSync, SupportsSetSync, SupportsDeleteSync, Protocol): ... + + async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: """Set or delete a value in a byte setter diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 62ceff7659..d05731d640 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -318,8 +318,6 @@ def _encode_sync( chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - # Since blosc only support host memory, we convert the input and output of the encoding - # between numpy array and buffer return chunk_spec.prototype.buffer.from_bytes( self._blosc_codec.encode(chunk_bytes.as_numpy_array()) ) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..4412ffa705 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,8 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -13,6 +13,7 @@ BytesBytesCodec, Codec, CodecPipeline, + SupportsSyncCodec, ) from zarr.core.common import concurrent_map from zarr.core.config import config @@ -68,6 +69,106 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value +@dataclass(frozen=True, slots=True) +class CodecChain: + """Codec chain with pre-resolved metadata specs. + + Constructed from an iterable of codecs and a chunk ArraySpec. + Resolves each codec against the spec so that encode/decode can + run without re-resolving. + """ + + codecs: tuple[Codec, ...] + chunk_spec: ArraySpec + + _aa_codecs: tuple[ArrayArrayCodec, ...] = field(init=False, repr=False, compare=False) + _aa_specs: tuple[ArraySpec, ...] = field(init=False, repr=False, compare=False) + _ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False) + _ab_spec: ArraySpec = field(init=False, repr=False, compare=False) + _bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False) + _bb_spec: ArraySpec = field(init=False, repr=False, compare=False) + _all_sync: bool = field(init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + aa, ab, bb = codecs_from_list(list(self.codecs)) + + aa_specs: list[ArraySpec] = [] + spec = self.chunk_spec + for aa_codec in aa: + aa_specs.append(spec) + spec = aa_codec.resolve_metadata(spec) + + object.__setattr__(self, "_aa_codecs", aa) + object.__setattr__(self, "_aa_specs", tuple(aa_specs)) + object.__setattr__(self, "_ab_codec", ab) + object.__setattr__(self, "_ab_spec", spec) + + spec = ab.resolve_metadata(spec) + object.__setattr__(self, "_bb_codecs", bb) + object.__setattr__(self, "_bb_spec", spec) + + object.__setattr__( + self, + "_all_sync", + all(isinstance(c, SupportsSyncCodec) for c in self.codecs), + ) + + @property + def all_sync(self) -> bool: + return self._all_sync + + def decode_chunk( + self, + chunk_bytes: Buffer, + ) -> NDBuffer: + """Decode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + """ + bb_out: Any = chunk_bytes + for bb_codec in reversed(self._bb_codecs): + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._bb_spec) + + ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec) + + for aa_codec, spec in zip(reversed(self._aa_codecs), reversed(self._aa_specs), strict=True): + ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + def encode_chunk( + self, + chunk_array: NDBuffer, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + """ + aa_out: Any = chunk_array + + for aa_codec, spec in zip(self._aa_codecs, self._aa_specs, strict=True): + if aa_out is None: + return None + aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) + + if aa_out is None: + return None + bb_out: Any = cast("SupportsSyncCodec", self._ab_codec)._encode_sync(aa_out, self._ab_spec) + + for bb_codec in self._bb_codecs: + if bb_out is None: + return None + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._bb_spec) + + return bb_out # type: ignore[no-any-return] + + def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: + for codec in self.codecs: + byte_length = codec.compute_encoded_size(byte_length, array_spec) + array_spec = codec.resolve_metadata(array_spec) + return byte_length + + @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): """Default codec pipeline. diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..08c05864aa 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -5,7 +5,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import ( + ByteRequest, + Store, + SupportsDeleteSync, + SupportsGetSync, + SupportsSetSync, +) from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import ( ANY_ACCESS_MODE, @@ -228,6 +234,37 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + # ------------------------------------------------------------------- + # Synchronous IO delegation + # ------------------------------------------------------------------- + + def get_sync( + self, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + """Synchronous read — delegates to ``self.store.get_sync(self.path, ...)``.""" + if not isinstance(self.store, SupportsGetSync): + raise TypeError(f"Store {type(self.store).__name__} does not support synchronous get.") + if prototype is None: + prototype = default_buffer_prototype() + return self.store.get_sync(self.path, prototype=prototype, byte_range=byte_range) + + def set_sync(self, value: Buffer) -> None: + """Synchronous write — delegates to ``self.store.set_sync(self.path, value)``.""" + if not isinstance(self.store, SupportsSetSync): + raise TypeError(f"Store {type(self.store).__name__} does not support synchronous set.") + self.store.set_sync(self.path, value) + + def delete_sync(self) -> None: + """Synchronous delete — delegates to ``self.store.delete_sync(self.path)``.""" + if not isinstance(self.store, SupportsDeleteSync): + raise TypeError( + f"Store {type(self.store).__name__} does not support synchronous delete." + ) + self.store.delete_sync(self.path) + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..96f1e61746 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -187,6 +187,56 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def _ensure_open_sync(self) -> None: + if not self._is_open: + if not self.read_only: + self.root.mkdir(parents=True, exist_ok=True) + if not self.root.exists(): + raise FileNotFoundError(f"{self.root} does not exist") + self._is_open = True + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + self._ensure_open_sync() + assert isinstance(key, str) + path = self.root / key + try: + return _get(path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._ensure_open_sync() + self._check_writable() + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"LocalStore.set(): `value` must be a Buffer instance. " + f"Got an instance of {type(value)} instead." + ) + path = self.root / key + _put(path, value) + + def delete_sync(self, key: str) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + async def get( self, key: str, diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..1194894b9d 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -77,6 +77,49 @@ def __eq__(self, other: object) -> bool: and self.read_only == other.read_only ) + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + try: + value = self._store_dict[key] + start, stop = _normalize_byte_range_index(value, byte_range) + return prototype.buffer.from_buffer(value[start:stop]) + except KeyError: + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + self._store_dict[key] = value + + def delete_sync(self, key: str) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + try: + del self._store_dict[key] + except KeyError: + logger.debug("Key %s does not exist.", key) + async def get( self, key: str, @@ -122,7 +165,6 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None raise TypeError( f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) - if byte_range is not None: buf = self._store_dict[key] buf[byte_range[0] : byte_range[1]] = value diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1b8e85ed98..ce83715b86 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from typing import Any - from zarr.abc.store import ByteRequest from zarr.core.buffer.core import BufferPrototype import pytest @@ -22,6 +21,9 @@ RangeByteRequest, Store, SuffixByteRequest, + SupportsDeleteSync, + SupportsGetSync, + SupportsSetSync, ) from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync @@ -39,6 +41,27 @@ class StoreTests(Generic[S, B]): store_cls: type[S] buffer_cls: type[B] + @staticmethod + def _require_get_sync(store: S) -> SupportsGetSync: + """Skip unless *store* implements :class:`SupportsGetSync`.""" + if not isinstance(store, SupportsGetSync): + pytest.skip("store does not implement SupportsGetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_set_sync(store: S) -> SupportsSetSync: + """Skip unless *store* implements :class:`SupportsSetSync`.""" + if not isinstance(store, SupportsSetSync): + pytest.skip("store does not implement SupportsSetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_delete_sync(store: S) -> SupportsDeleteSync: + """Skip unless *store* implements :class:`SupportsDeleteSync`.""" + if not isinstance(store, SupportsDeleteSync): + pytest.skip("store does not implement SupportsDeleteSync") + return store # type: ignore[unreachable] + @abstractmethod async def set(self, store: S, key: str, value: Buffer) -> None: """ @@ -579,6 +602,52 @@ def test_get_json_sync(self, store: S) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) assert store._get_json_sync(key, prototype=default_buffer_prototype()) == data + # ------------------------------------------------------------------- + # Synchronous store methods (SupportsSyncStore protocol) + # ------------------------------------------------------------------- + + def test_get_sync(self, store: S) -> None: + getter = self._require_get_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_get" + sync(self.set(store, key, data_buf)) + result = getter.get_sync(key) + assert result is not None + assert_bytes_equal(result, data_buf) + + def test_get_sync_missing(self, store: S) -> None: + getter = self._require_get_sync(store) + result = getter.get_sync("nonexistent") + assert result is None + + def test_set_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_set" + setter.set_sync(key, data_buf) + result = sync(self.get(store, key)) + assert_bytes_equal(result, data_buf) + + def test_delete_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + deleter = self._require_delete_sync(store) + getter = self._require_get_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_delete" + setter.set_sync(key, data_buf) + deleter.delete_sync(key) + result = getter.get_sync(key) + assert result is None + + def test_delete_sync_missing(self, store: S) -> None: + deleter = self._require_delete_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + # should not raise + deleter.delete_sync("nonexistent_sync") + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_codecs/test_blosc.py b/tests/test_codecs/test_blosc.py index 6f4821f8b1..0201beb8de 100644 --- a/tests/test_codecs/test_blosc.py +++ b/tests/test_codecs/test_blosc.py @@ -6,11 +6,12 @@ from packaging.version import Version import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.codecs import BloscCodec from zarr.codecs.blosc import BloscShuffle, Shuffle -from zarr.core.array_spec import ArraySpec +from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import default_buffer_prototype -from zarr.core.dtype import UInt16 +from zarr.core.dtype import UInt16, get_data_type_from_native_dtype from zarr.storage import MemoryStore, StorePath @@ -110,3 +111,27 @@ async def test_typesize() -> None: else: expected_size = 10216 assert size == expected_size, msg + + +def test_blosc_codec_supports_sync() -> None: + assert isinstance(BloscCodec(), SupportsSyncCodec) + + +def test_blosc_codec_sync_roundtrip() -> None: + codec = BloscCodec(typesize=8) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_crc32c.py b/tests/test_codecs/test_crc32c.py new file mode 100644 index 0000000000..3ab1070f60 --- /dev/null +++ b/tests/test_codecs/test_crc32c.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import numpy as np + +from zarr.abc.codec import SupportsSyncCodec +from zarr.codecs.crc32c_ import Crc32cCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype + + +def test_crc32c_codec_supports_sync() -> None: + assert isinstance(Crc32cCodec(), SupportsSyncCodec) + + +def test_crc32c_codec_sync_roundtrip() -> None: + codec = Crc32cCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_endian.py b/tests/test_codecs/test_endian.py index ab64afb1b8..c505cee828 100644 --- a/tests/test_codecs/test_endian.py +++ b/tests/test_codecs/test_endian.py @@ -4,8 +4,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import BytesCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -33,6 +37,31 @@ async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: assert np.array_equal(data, readback_data) +def test_bytes_codec_supports_sync() -> None: + assert isinstance(BytesCodec(), SupportsSyncCodec) + + +def test_bytes_codec_sync_roundtrip() -> None: + codec = BytesCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + codec = codec.evolve_from_array_spec(spec) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + @pytest.mark.filterwarnings("ignore:The endianness of the requested serializer") @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize("dtype_input_endian", [">u2", " None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_gzip_codec_supports_sync() -> None: + assert isinstance(GzipCodec(), SupportsSyncCodec) + + +def test_gzip_codec_sync_roundtrip() -> None: + codec = GzipCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_transpose.py b/tests/test_codecs/test_transpose.py index 06ec668ad3..949bb72a62 100644 --- a/tests/test_codecs/test_transpose.py +++ b/tests/test_codecs/test_transpose.py @@ -3,9 +3,13 @@ import zarr from zarr import AsyncArray, config +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import TransposeCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype from zarr.core.common import MemoryOrder +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -93,3 +97,27 @@ def test_transpose_invalid( chunk_key_encoding={"name": "v2", "separator": "."}, filters=[TransposeCodec(order=order)], # type: ignore[arg-type] ) + + +def test_transpose_codec_supports_sync() -> None: + assert isinstance(TransposeCodec(order=(0, 1)), SupportsSyncCodec) + + +def test_transpose_codec_sync_roundtrip() -> None: + codec = TransposeCodec(order=(1, 0)) + arr = np.arange(12, dtype="float64").reshape(3, 4) + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + resolved_spec = codec.resolve_metadata(spec) + decoded = codec._decode_sync(encoded, resolved_spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) diff --git a/tests/test_codecs/test_vlen.py b/tests/test_codecs/test_vlen.py index cf0905daca..f3445824b3 100644 --- a/tests/test_codecs/test_vlen.py +++ b/tests/test_codecs/test_vlen.py @@ -5,9 +5,10 @@ import zarr from zarr import Array -from zarr.abc.codec import Codec +from zarr.abc.codec import Codec, SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.core.dtype import get_data_type_from_native_dtype from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING from zarr.core.metadata.v3 import ArrayV3Metadata @@ -62,3 +63,11 @@ def test_vlen_string( assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == get_data_type_from_native_dtype(data.dtype) assert a.dtype == data.dtype + + +def test_vlen_utf8_codec_supports_sync() -> None: + assert isinstance(VLenUTF8Codec(), SupportsSyncCodec) + + +def test_vlen_bytes_codec_supports_sync() -> None: + assert isinstance(VLenBytesCodec(), SupportsSyncCodec) diff --git a/tests/test_codecs/test_zstd.py b/tests/test_codecs/test_zstd.py index 6068f53443..3f3f15a41a 100644 --- a/tests/test_codecs/test_zstd.py +++ b/tests/test_codecs/test_zstd.py @@ -2,8 +2,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath @@ -23,3 +27,27 @@ def test_zstd(store: Store, checksum: bool) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_zstd_codec_supports_sync() -> None: + assert isinstance(ZstdCodec(), SupportsSyncCodec) + + +def test_zstd_codec_sync_roundtrip() -> None: + codec = ZstdCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..9c734fb0c3 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator + from zarr.abc.store import ByteRequest from zarr.core.buffer import BufferPrototype from zarr.core.buffer.core import Buffer @@ -83,6 +84,22 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None self.counter["__setitem__", key_suffix] += 1 return await super().set(key, value, byte_range) + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__getitem__", key_suffix] += 1 + return super().get_sync(key, prototype=prototype, byte_range=byte_range) + + def set_sync(self, key: str, value: Buffer) -> None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__setitem__", key_suffix] += 1 + return super().set_sync(key, value) + def test_normalize_integer_selection() -> None: assert 1 == normalize_integer_selection(1, 100) diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py new file mode 100644 index 0000000000..192479dc59 --- /dev/null +++ b/tests/test_sync_codec_pipeline.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.gzip import GzipCodec +from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.zstd import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.codec_pipeline import CodecChain +from zarr.core.dtype import get_data_type_from_native_dtype + + +def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec: + zdtype = get_data_type_from_native_dtype(dtype) + return ArraySpec( + shape=shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + + +def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer: + return default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + +class TestCodecChain: + def test_all_sync(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = CodecChain((BytesCodec(),), spec) + assert chain.all_sync is True + + def test_all_sync_with_compression(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = CodecChain((BytesCodec(), GzipCodec()), spec) + assert chain.all_sync is True + + def test_all_sync_full_chain(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), spec) + assert chain.all_sync is True + + def test_encode_decode_roundtrip_bytes_only(self) -> None: + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain = CodecChain((BytesCodec(),), spec) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_encode_decode_roundtrip_with_compression(self) -> None: + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain = CodecChain((BytesCodec(), GzipCodec(level=1)), spec) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_encode_decode_roundtrip_with_transpose(self) -> None: + arr = np.arange(12, dtype="float64").reshape(3, 4) + spec = _make_array_spec(arr.shape, arr.dtype) + chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), spec) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array())