From b0f5494407acd707e80ece8e0be852b1b0219b80 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 4 Mar 2026 13:40:57 -0800 Subject: [PATCH] Add bilateral filter for feature-preserving smoothing (#969) --- README.md | 1 + docs/source/reference/focal.rst | 7 + examples/user_guide/17_Bilateral_Filter.ipynb | 185 ++++++++++ xrspatial/__init__.py | 1 + xrspatial/accessor.py | 8 + xrspatial/bilateral.py | 333 ++++++++++++++++++ xrspatial/tests/test_bilateral.py | 279 +++++++++++++++ 7 files changed, 814 insertions(+) create mode 100644 examples/user_guide/17_Bilateral_Filter.ipynb create mode 100644 xrspatial/bilateral.py create mode 100644 xrspatial/tests/test_bilateral.py diff --git a/README.md b/README.md index 68d27b0f..319efe13 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e | [Emerging Hotspots](xrspatial/emerging_hotspots.py) | Classifies time-series hot/cold spot trends using Gi* and Mann-Kendall | ✅️ | ✅️ | ✅️ | ✅️ | | [Mean](xrspatial/focal.py) | Computes the mean value within a sliding neighborhood window | ✅️ | ✅️ | ✅️ | ✅️ | | [Focal Statistics](xrspatial/focal.py) | Computes summary statistics over a sliding neighborhood window | ✅️ | ✅️ | ✅️ | ✅️ | +| [Bilateral](xrspatial/bilateral.py) | Feature-preserving smoothing via bilateral filtering | ✅️ | ✅️ | ✅️ | ✅️ | ------- diff --git a/docs/source/reference/focal.rst b/docs/source/reference/focal.rst index e471463d..8b67c046 100644 --- a/docs/source/reference/focal.rst +++ b/docs/source/reference/focal.rst @@ -25,6 +25,13 @@ Mean xrspatial.focal.mean +Bilateral +========= +.. autosummary:: + :toctree: _autosummary + + xrspatial.bilateral.bilateral + Focal Statistics ================ diff --git a/examples/user_guide/17_Bilateral_Filter.ipynb b/examples/user_guide/17_Bilateral_Filter.ipynb new file mode 100644 index 00000000..24450db7 --- /dev/null +++ b/examples/user_guide/17_Bilateral_Filter.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bilateral filtering\n", + "\n", + "The bilateral filter smooths a raster while preserving edges. Unlike a simple mean filter, it weights each neighbor by both spatial distance and value similarity. Pixels across a sharp boundary contribute very little, so edges stay sharp while flat areas get smoothed.\n", + "\n", + "Two parameters control the behavior:\n", + "- **sigma_spatial**: how far the spatial Gaussian reaches (kernel radius = ceil(2 * sigma_spatial))\n", + "- **sigma_range**: how much value difference is tolerated before a neighbor gets downweighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from xrspatial import bilateral\n", + "from xrspatial import mean\n", + "from xrspatial.terrain import generate_terrain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate a synthetic terrain with noise\n", + "\n", + "We'll create a DEM, add Gaussian noise, and then compare bilateral filtering against the standard mean filter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "W, H = 600, 400\n", + "cvs_terrain = xr.DataArray(\n", + " np.zeros((H, W)),\n", + " dims=['y', 'x'],\n", + " coords={'y': np.linspace(0, 100, H), 'x': np.linspace(0, 150, W)},\n", + ")\n", + "terrain = generate_terrain(cvs_terrain, seed=42)\n", + "\n", + "# Add noise\n", + "rng = np.random.default_rng(123)\n", + "noise = rng.normal(0, 15, terrain.shape)\n", + "noisy_terrain = terrain.copy(data=terrain.values + noise)\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "terrain.plot(ax=axes[0], cmap='terrain')\n", + "axes[0].set_title('Clean terrain')\n", + "noisy_terrain.plot(ax=axes[1], cmap='terrain')\n", + "axes[1].set_title('Noisy terrain')\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare bilateral vs. mean filter\n", + "\n", + "The mean filter blurs edges. The bilateral filter preserves them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "smoothed_bilateral = bilateral(noisy_terrain, sigma_spatial=2.0, sigma_range=20.0)\n", + "smoothed_mean = mean(noisy_terrain, passes=3)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", + "\n", + "noisy_terrain.plot(ax=axes[0], cmap='terrain')\n", + "axes[0].set_title('Noisy input')\n", + "\n", + "smoothed_mean.plot(ax=axes[1], cmap='terrain')\n", + "axes[1].set_title('Mean filter (3 passes)')\n", + "\n", + "smoothed_bilateral.plot(ax=axes[2], cmap='terrain')\n", + "axes[2].set_title('Bilateral filter')\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Effect of sigma_range\n", + "\n", + "Smaller `sigma_range` preserves more edges; larger values allow smoothing across bigger value differences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sigma_ranges = [5.0, 20.0, 100.0]\n", + "\n", + "fig, axes = plt.subplots(1, len(sigma_ranges), figsize=(18, 5))\n", + "for ax, sr in zip(axes, sigma_ranges):\n", + " result = bilateral(noisy_terrain, sigma_spatial=2.0, sigma_range=sr)\n", + " result.plot(ax=ax, cmap='terrain')\n", + " ax.set_title(f'sigma_range = {sr}')\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step-edge preservation\n", + "\n", + "A clear demonstration: a raster with a sharp vertical edge. The bilateral filter keeps the boundary; the mean filter blurs it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "step = np.zeros((50, 100))\n", + "step[:, 50:] = 100.0\n", + "\n", + "# Add a bit of noise\n", + "step_noisy = step + rng.normal(0, 5, step.shape)\n", + "step_agg = xr.DataArray(step_noisy, dims=['y', 'x'])\n", + "\n", + "step_bilateral = bilateral(step_agg, sigma_spatial=2.0, sigma_range=10.0)\n", + "step_mean = mean(step_agg, passes=3)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n", + "step_agg.plot(ax=axes[0], cmap='gray')\n", + "axes[0].set_title('Noisy step edge')\n", + "step_mean.plot(ax=axes[1], cmap='gray')\n", + "axes[1].set_title('Mean filter')\n", + "step_bilateral.plot(ax=axes[2], cmap='gray')\n", + "axes[2].set_title('Bilateral filter')\n", + "plt.tight_layout()\n", + "\n", + "# Cross-section\n", + "row = 25\n", + "fig, ax = plt.subplots(figsize=(10, 4))\n", + "ax.plot(step_agg.data[row], label='Noisy', alpha=0.5)\n", + "ax.plot(step_mean.data[row], label='Mean', linewidth=2)\n", + "ax.plot(step_bilateral.data[row], label='Bilateral', linewidth=2)\n", + "ax.legend()\n", + "ax.set_xlabel('Column')\n", + "ax.set_ylabel('Value')\n", + "ax.set_title('Cross-section at row 25')\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/xrspatial/__init__.py b/xrspatial/__init__.py index 7f2ba46e..ef970847 100644 --- a/xrspatial/__init__.py +++ b/xrspatial/__init__.py @@ -1,5 +1,6 @@ from xrspatial.aspect import aspect # noqa from xrspatial.balanced_allocation import balanced_allocation # noqa +from xrspatial.bilateral import bilateral # noqa from xrspatial.bump import bump # noqa from xrspatial.cost_distance import cost_distance # noqa from xrspatial.dasymetric import disaggregate # noqa diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index 837b3864..44342778 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -173,6 +173,10 @@ def focal_mean(self, **kwargs): from .focal import mean return mean(self._obj, **kwargs) + def bilateral(self, **kwargs): + from .bilateral import bilateral + return bilateral(self._obj, **kwargs) + # ---- Proximity / Distance ---- def proximity(self, **kwargs): @@ -501,6 +505,10 @@ def focal_mean(self, **kwargs): from .focal import mean return mean(self._obj, **kwargs) + def bilateral(self, **kwargs): + from .bilateral import bilateral + return bilateral(self._obj, **kwargs) + # ---- Diffusion ---- def diffuse(self, **kwargs): diff --git a/xrspatial/bilateral.py b/xrspatial/bilateral.py new file mode 100644 index 00000000..474e6b6c --- /dev/null +++ b/xrspatial/bilateral.py @@ -0,0 +1,333 @@ +"""Bilateral filter for feature-preserving smoothing. + +Smooths a raster while preserving edges by weighting neighbors based on +both spatial distance and value similarity. + +Reference +--------- +Tomasi & Manduchi, "Bilateral Filtering for Gray and Color Images," ICCV 1998. +""" +from __future__ import annotations + +from functools import partial +from math import ceil, exp, isnan + +import numpy as np +import xarray as xr +from numba import cuda +from xarray import DataArray + +try: + import dask.array as da +except ImportError: + da = None + +try: + import cupy +except ImportError: + class cupy(object): + ndarray = False + +from xrspatial.utils import ( + ArrayTypeFunctionMapping, + _boundary_to_dask, + _pad_array, + _validate_boundary, + _validate_raster, + _validate_scalar, + cuda_args, + ngjit, +) +from xrspatial.dataset_support import supports_dataset + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _kernel_radius(sigma_spatial): + """Derive kernel radius from sigma_spatial: ceil(2 * sigma).""" + return int(ceil(2.0 * sigma_spatial)) + + +# --------------------------------------------------------------------------- +# NumPy backend +# --------------------------------------------------------------------------- + +@ngjit +def _bilateral_numpy(data, radius, sigma_spatial, sigma_range): + rows, cols = data.shape + out = np.empty_like(data) + inv_2_ss = 1.0 / (2.0 * sigma_spatial * sigma_spatial) + inv_2_sr = 1.0 / (2.0 * sigma_range * sigma_range) + + for y in range(rows): + for x in range(cols): + center = data[y, x] + if isnan(center): + out[y, x] = np.nan + continue + + w_sum = 0.0 + v_sum = 0.0 + + y0 = max(y - radius, 0) + y1 = min(y + radius + 1, rows) + x0 = max(x - radius, 0) + x1 = min(x + radius + 1, cols) + + for ny in range(y0, y1): + for nx in range(x0, x1): + val = data[ny, nx] + if isnan(val): + continue + dy = ny - y + dx = nx - x + dist2 = dy * dy + dx * dx + diff = val - center + range2 = diff * diff + + w = exp(-dist2 * inv_2_ss - range2 * inv_2_sr) + w_sum += w + v_sum += w * val + + if w_sum > 0.0: + out[y, x] = v_sum / w_sum + else: + out[y, x] = np.nan + + return out + + +def _bilateral_numpy_boundary(data, radius, sigma_spatial, sigma_range, + boundary='nan'): + if boundary == 'nan': + return _bilateral_numpy(data, radius, sigma_spatial, sigma_range) + padded = _pad_array(data, radius, boundary) + result = _bilateral_numpy(padded, radius, sigma_spatial, sigma_range) + return result[radius:-radius, radius:-radius] + + +# --------------------------------------------------------------------------- +# Dask + NumPy backend +# --------------------------------------------------------------------------- + +def _bilateral_dask_numpy(data, radius, sigma_spatial, sigma_range, + boundary='nan'): + _func = partial( + _bilateral_numpy, radius=radius, + sigma_spatial=sigma_spatial, sigma_range=sigma_range, + ) + out = data.map_overlap( + _func, + depth=(radius, radius), + boundary=_boundary_to_dask(boundary), + meta=np.array(()), + ) + return out + + +# --------------------------------------------------------------------------- +# CuPy (GPU) backend +# --------------------------------------------------------------------------- + +@cuda.jit +def _bilateral_gpu(data, radius, inv_2_ss, inv_2_sr, out): + x, y = cuda.grid(2) + rows, cols = data.shape + + if 0 <= x < cols and 0 <= y < rows: + center = data[y, x] + if isnan(center): + out[y, x] = center + return + + w_sum = 0.0 + v_sum = 0.0 + + y0 = max(y - radius, 0) + y1 = min(y + radius + 1, rows) + x0 = max(x - radius, 0) + x1 = min(x + radius + 1, cols) + + for ny in range(y0, y1): + for nx in range(x0, x1): + val = data[ny, nx] + if not isnan(val): + dy = ny - y + dx = nx - x + dist2 = dy * dy + dx * dx + diff = val - center + range2 = diff * diff + w = exp(-dist2 * inv_2_ss - range2 * inv_2_sr) + w_sum += w + v_sum += w * val + + if w_sum > 0.0: + out[y, x] = v_sum / w_sum + else: + out[y, x] = center + + +def _bilateral_cupy(data, radius, sigma_spatial, sigma_range): + data_cu = cupy.asarray(data, dtype=cupy.float64) + inv_2_ss = 1.0 / (2.0 * sigma_spatial * sigma_spatial) + inv_2_sr = 1.0 / (2.0 * sigma_range * sigma_range) + + griddim, blockdim = cuda_args(data_cu.shape) + out = cupy.empty_like(data_cu) + + _bilateral_gpu[griddim, blockdim]( + data_cu, radius, inv_2_ss, inv_2_sr, out, + ) + return out + + +# --------------------------------------------------------------------------- +# Dask + CuPy backend +# --------------------------------------------------------------------------- + +def _bilateral_dask_cupy(data, radius, sigma_spatial, sigma_range, + boundary='nan'): + _func = partial( + _bilateral_cupy, radius=radius, + sigma_spatial=sigma_spatial, sigma_range=sigma_range, + ) + out = data.map_overlap( + _func, + depth=(radius, radius), + boundary=_boundary_to_dask(boundary, is_cupy=True), + meta=cupy.array(()), + ) + return out + + +# --------------------------------------------------------------------------- +# dispatcher +# --------------------------------------------------------------------------- + +def _bilateral(data, radius, sigma_spatial, sigma_range, boundary='nan'): + agg = xr.DataArray(data) + mapper = ArrayTypeFunctionMapping( + numpy_func=partial( + _bilateral_numpy_boundary, + boundary=boundary, + ), + cupy_func=_bilateral_cupy, + dask_func=partial( + _bilateral_dask_numpy, + boundary=boundary, + ), + dask_cupy_func=partial( + _bilateral_dask_cupy, + boundary=boundary, + ), + ) + out = mapper(agg)( + agg.data, + radius=radius, + sigma_spatial=sigma_spatial, + sigma_range=sigma_range, + ) + return out + + +# --------------------------------------------------------------------------- +# public API +# --------------------------------------------------------------------------- + +@supports_dataset +def bilateral(agg, sigma_spatial=1.0, sigma_range=10.0, + name='bilateral', boundary='nan'): + """Apply a bilateral filter for feature-preserving smoothing. + + Smooths a raster while preserving edges. Each pixel is replaced by + a weighted average of its neighbours, where the weight depends on + both the spatial distance and the value difference between the + neighbour and the center pixel. Neighbours that are far away *or* + very different in value contribute little, so edges stay sharp. + + Parameters + ---------- + agg : xarray.DataArray or xr.Dataset + 2D (or 3D multi-band) array of input values. + Supports NumPy, CuPy, Dask+NumPy, and Dask+CuPy backends. + sigma_spatial : float, default 1.0 + Standard deviation of the spatial Gaussian. Controls the size + of the neighbourhood: kernel radius = ceil(2 * sigma_spatial). + Must be > 0. + sigma_range : float, default 10.0 + Standard deviation of the range (value-similarity) Gaussian. + Larger values allow more smoothing across value differences; + smaller values preserve more edges. Must be > 0. + name : str, default 'bilateral' + Name for the output DataArray. + boundary : str, default 'nan' + How to handle edges where the kernel extends beyond the raster. + ``'nan'`` -- fill missing neighbours with NaN (default). + ``'nearest'`` -- repeat edge values. + ``'reflect'`` -- mirror at boundary. + ``'wrap'`` -- periodic / toroidal. + + Returns + ------- + out : xarray.DataArray or xr.Dataset + Filtered array of the same shape, dtype, dims, and coords as + the input. + + Examples + -------- + .. sourcecode:: python + + >>> import numpy as np + >>> import xarray as xr + >>> from xrspatial import bilateral + >>> data = np.array([ + ... [0., 0., 0., 100., 100.], + ... [0., 0., 0., 100., 100.], + ... [0., 0., 0., 100., 100.], + ... [0., 0., 0., 100., 100.], + ... [0., 0., 0., 100., 100.]]) + >>> raster = xr.DataArray(data) + >>> smoothed = bilateral(raster, sigma_spatial=1.0, sigma_range=5.0) + + References + ---------- + Tomasi & Manduchi, "Bilateral Filtering for Gray and Color Images," + ICCV 1998. + """ + _validate_raster(agg, func_name='bilateral', name='agg', ndim=(2, 3)) + _validate_scalar(sigma_spatial, func_name='bilateral', + name='sigma_spatial', dtype=(int, float), + min_val=0, min_exclusive=True) + _validate_scalar(sigma_range, func_name='bilateral', + name='sigma_range', dtype=(int, float), + min_val=0, min_exclusive=True) + _validate_boundary(boundary) + + sigma_spatial = float(sigma_spatial) + sigma_range = float(sigma_range) + radius = _kernel_radius(sigma_spatial) + + if agg.ndim == 3: + from xrspatial.focal import _apply_per_band + return _apply_per_band( + bilateral, agg, + sigma_spatial=sigma_spatial, + sigma_range=sigma_range, + name=name, + boundary=boundary, + ) + + out = _bilateral( + agg.data.astype(float), + radius, sigma_spatial, sigma_range, boundary, + ) + + return DataArray( + out, + name=name, + dims=agg.dims, + coords=agg.coords, + attrs=agg.attrs, + ) diff --git a/xrspatial/tests/test_bilateral.py b/xrspatial/tests/test_bilateral.py new file mode 100644 index 00000000..35fb808e --- /dev/null +++ b/xrspatial/tests/test_bilateral.py @@ -0,0 +1,279 @@ +"""Tests for xrspatial.bilateral.""" +try: + import dask.array as da +except ImportError: + da = None + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.bilateral import bilateral, _bilateral_numpy, _kernel_radius +from xrspatial.tests.general_checks import ( + assert_boundary_mode_correctness, + create_test_raster, + cuda_and_cupy_available, + dask_array_available, + general_output_checks, +) + + +# ---- fixtures ---- + +@pytest.fixture +def step_edge(): + """5x10 raster with a sharp vertical edge: left=0, right=100.""" + data = np.zeros((5, 10), dtype=np.float64) + data[:, 5:] = 100.0 + return data + + +@pytest.fixture +def random_data_969(): + rng = np.random.default_rng(969) + return rng.standard_normal((30, 30)).astype(np.float64) + + +# ---- unit tests ---- + +def test_kernel_radius(): + assert _kernel_radius(1.0) == 2 + assert _kernel_radius(0.5) == 1 + assert _kernel_radius(2.0) == 4 + assert _kernel_radius(1.5) == 3 + + +def test_bilateral_preserves_flat_surface(): + """Flat raster should be unchanged after bilateral filtering.""" + data = np.full((8, 8), 42.0) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=10.0) + np.testing.assert_allclose(result.data, 42.0) + + +def test_bilateral_preserves_edge(step_edge): + """With a small sigma_range, the sharp edge should remain sharp.""" + agg = xr.DataArray(step_edge) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=1.0) + + # Interior left columns (far from edge) should stay near 0 + np.testing.assert_allclose(result.data[:, 0:3], 0.0, atol=1e-10) + + # Interior right columns (far from edge) should stay near 100 + np.testing.assert_allclose(result.data[:, 7:10], 100.0, atol=1e-10) + + +def test_bilateral_smooths_with_large_sigma_range(step_edge): + """With a large sigma_range the filter should blur across the edge.""" + agg = xr.DataArray(step_edge) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=1000.0) + + # Column 4 (just left of edge) should now be pulled toward 100 + col4_mean = result.data[2, 4] + assert col4_mean > 10.0, f"Expected blurring but got {col4_mean}" + + # Column 5 (just right of edge) should be pulled toward 0 + col5_mean = result.data[2, 5] + assert col5_mean < 90.0, f"Expected blurring but got {col5_mean}" + + +def test_bilateral_nan_handling(): + """NaN center pixels should stay NaN; NaN neighbors should be skipped.""" + data = np.array([ + [1.0, 2.0, 3.0], + [4.0, np.nan, 6.0], + [7.0, 8.0, 9.0], + ]) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=100.0) + + # Center pixel stays NaN + assert np.isnan(result.data[1, 1]) + + # Non-NaN pixels should produce finite values + assert np.all(np.isfinite(result.data[~np.isnan(data)])) + + +def test_bilateral_all_nan(): + """All-NaN raster should remain all-NaN.""" + data = np.full((4, 4), np.nan) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=10.0) + assert np.all(np.isnan(result.data)) + + +def test_bilateral_single_cell(): + """Single-cell raster should return same value.""" + data = np.array([[7.0]]) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=10.0) + np.testing.assert_allclose(result.data, 7.0) + + +def test_bilateral_output_checks(random_data_969): + """Output should preserve shape, dims, coords, attrs.""" + agg = create_test_raster(random_data_969) + result = bilateral(agg) + general_output_checks(agg, result) + + +def test_bilateral_validation_errors(): + """Check that invalid inputs raise appropriate errors.""" + data = np.ones((5, 5)) + agg = xr.DataArray(data) + + with pytest.raises(TypeError): + bilateral("not_a_dataarray") + + with pytest.raises(ValueError): + bilateral(agg, sigma_spatial=-1.0) + + with pytest.raises(ValueError): + bilateral(agg, sigma_range=0.0) + + with pytest.raises(ValueError): + bilateral(agg, boundary='invalid') + + +def test_bilateral_known_values(): + """Verify output against hand-computed bilateral filter on a 3x3 raster.""" + data = np.array([ + [10.0, 10.0, 10.0], + [10.0, 10.0, 50.0], + [10.0, 10.0, 10.0], + ]) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=5.0) + + # Center pixel (1,1): neighbors are mostly 10, one is 50. + # With sigma_range=5, the 50-value neighbor has very low weight + # (diff=40, exp(-40^2/(2*25)) ~ 0), so result should be near 10. + assert abs(result.data[1, 1] - 10.0) < 1.0 + + # The 50-value pixel (1,2): all its non-NaN neighbors are 10 except itself. + # The 10-value neighbors have diff=40 so low weight. + # Self (distance=0) has full weight. Result should stay near 50. + assert abs(result.data[1, 2] - 50.0) < 1.0 + + +def test_bilateral_symmetry(): + """Symmetric input should produce symmetric output.""" + data = np.array([ + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 100., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + ]) + agg = xr.DataArray(data) + result = bilateral(agg, sigma_spatial=1.0, sigma_range=50.0) + r = result.data + + # Should be symmetric around center + np.testing.assert_allclose(r, r[::-1, :], atol=1e-12) + np.testing.assert_allclose(r, r[:, ::-1], atol=1e-12) + + +# ---- cross-backend tests ---- + +def test_bilateral_cpu(random_data_969): + numpy_agg = create_test_raster(random_data_969) + result = bilateral(numpy_agg) + general_output_checks(numpy_agg, result) + + +@dask_array_available +def test_bilateral_dask_cpu(random_data_969): + numpy_agg = create_test_raster(random_data_969) + numpy_result = bilateral(numpy_agg) + + dask_agg = create_test_raster(random_data_969, backend='dask+numpy', + chunks=(15, 15)) + dask_result = bilateral(dask_agg) + general_output_checks(dask_agg, dask_result) + + np.testing.assert_allclose( + numpy_result.data, dask_result.data.compute(), + equal_nan=True, rtol=1e-6, + ) + + +@cuda_and_cupy_available +def test_bilateral_gpu_equals_cpu(random_data_969): + import cupy + + numpy_agg = create_test_raster(random_data_969) + numpy_result = bilateral(numpy_agg) + + cupy_agg = create_test_raster(random_data_969, backend='cupy') + cupy_result = bilateral(cupy_agg) + general_output_checks(cupy_agg, cupy_result) + + np.testing.assert_allclose( + numpy_result.data, cupy_result.data.get(), + equal_nan=True, rtol=1e-6, + ) + + +@dask_array_available +@cuda_and_cupy_available +def test_bilateral_dask_gpu(random_data_969): + import cupy + + numpy_agg = create_test_raster(random_data_969) + numpy_result = bilateral(numpy_agg) + + dask_cupy_agg = create_test_raster(random_data_969, + backend='dask+cupy', + chunks=(15, 15)) + dask_cupy_result = bilateral(dask_cupy_agg) + general_output_checks(dask_cupy_agg, dask_cupy_result) + + np.testing.assert_allclose( + numpy_result.data, dask_cupy_result.data.compute().get(), + equal_nan=True, rtol=1e-4, + ) + + +# ---- boundary mode tests ---- + +@dask_array_available +def test_bilateral_boundary_modes(random_data_969): + numpy_agg = create_test_raster(random_data_969) + dask_agg = create_test_raster(random_data_969, backend='dask+numpy', + chunks=(15, 15)) + assert_boundary_mode_correctness( + numpy_agg, dask_agg, bilateral, + depth=2, rtol=1e-5, nan_edges=False, + ) + + +# ---- 3D multi-band test ---- + +def test_bilateral_3d(): + band0 = np.random.default_rng(100).standard_normal((10, 10)) + band1 = np.random.default_rng(200).standard_normal((10, 10)) + data = np.stack([band0, band1]) + agg = xr.DataArray(data, dims=['band', 'y', 'x']) + + result = bilateral(agg, sigma_spatial=1.0, sigma_range=10.0) + assert result.shape == agg.shape + + # Each band should match independent 2D processing + for i in range(2): + band_agg = xr.DataArray(data[i]) + band_result = bilateral(band_agg, sigma_spatial=1.0, sigma_range=10.0) + np.testing.assert_allclose( + result.data[i], band_result.data, equal_nan=True, + ) + + +# ---- accessor test ---- + +def test_bilateral_accessor(): + import xrspatial # noqa - registers accessor + data = np.random.default_rng(42).standard_normal((10, 10)) + agg = xr.DataArray(data, dims=['y', 'x']) + result = agg.xrs.bilateral(sigma_spatial=1.0, sigma_range=10.0) + assert result.shape == agg.shape + assert isinstance(result, xr.DataArray)