7 changes: 5 additions & 2 deletions deepnote_toolkit/sql/sql_caching.py
@@ -98,9 +98,12 @@ def upload_sql_cache(dataframe, upload_url):
     with tempfile.TemporaryFile() as temp_file:
         try:
             dataframe.to_parquet(temp_file)
-        except (ArrowNotImplementedError, ArrowInvalid):
+        except (ArrowNotImplementedError, ArrowInvalid, OverflowError):
             # see NB-1684
-            # we fallback to pickle if parquet serialization fails (which will throw either of these 2 errors)
+            # we fall back to pickle if parquet serialization fails (which will throw either of the first 2 errors)
+            # OverflowError: PyArrow raises this for Python int / Decimal values exceeding int64 range
+            temp_file.seek(0)
+            temp_file.truncate()
             dataframe.to_pickle(temp_file)
 
         temp_file.seek(0)
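For context, a minimal sketch of the failure mode this hunk guards against (assuming a pandas/PyArrow setup like the one the tests below exercise; the exact exception can vary by PyArrow version, which is why all three are caught):

import io

import pandas as pd

buf = io.BytesIO()
df = pd.DataFrame({"x": [2**100, 1]})  # Python int wider than int64
try:
    df.to_parquet(buf)                 # PyArrow cannot fit 2**100 into an int64 column
except OverflowError:
    buf.seek(0)
    buf.truncate()                     # drop any partially written parquet bytes
    df.to_pickle(buf)                  # fall back to pickle, as upload_sql_cache does
buf.seek(0)                            # buffer now holds one complete serialized frame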
15 changes: 10 additions & 5 deletions deepnote_toolkit/sql/sql_execution.py
@@ -5,6 +5,7 @@
 import uuid
 import warnings
 import weakref
+from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Optional
 from urllib.parse import quote

@@ -686,6 +687,14 @@ class BigQueryCredentialsError(Exception):
     return {"connect_args": {"client": client}}
 
 
+def _is_large_number(x: Any) -> bool:
+    """Return True if *x* is a numeric value that exceeds the int64 range"""
+    try:
+        return isinstance(x, (int, float, Decimal)) and abs(x) > 2**63 - 1
+    except (TypeError, OverflowError, ArithmeticError):
+        return False
+
+
 def _sanitize_dataframe_for_parquet(dataframe):
     """Sanitizes the dataframe so that we can safely call .to_parquet on it"""
 
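In rough terms, the helper classifies values as follows (a sketch mirroring the unit tests added below; note that decimal.InvalidOperation, raised by ordering comparisons against a Decimal NaN, is a subclass of ArithmeticError and is therefore swallowed):

from decimal import Decimal

_is_large_number(2**63)            # True  - one past the int64 maximum
_is_large_number(2**63 - 1)        # False - still representable as int64
_is_large_number(Decimal("1e40"))  # True
_is_large_number(Decimal("NaN"))   # False - comparison signals InvalidOperation, caught above
_is_large_number("not a number")   # False - fails the isinstance check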
@@ -707,11 +716,7 @@ def _sanitize_dataframe_for_parquet(dataframe):
 
     # Convert columns with large numbers to strings
     for column in dataframe.columns:
-        if (
-            dataframe[column]
-            .apply(lambda x: isinstance(x, (int, float)) and abs(x) > 2**63 - 1)
-            .any()
-        ):
+        if dataframe[column].apply(_is_large_number).any():
             dataframe[column] = dataframe[column].astype(str)
 
 
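A short usage sketch of the sanitizer after this change (it is a private helper; the module path and the in-place mutation match the tests below):

from decimal import Decimal

import pandas as pd

from deepnote_toolkit.sql import sql_execution as se

df = pd.DataFrame({"big": [Decimal("1e40"), 2**100], "small": [1, 2]})
se._sanitize_dataframe_for_parquet(df)             # mutates df in place
print(df["big"].tolist())                          # ['1E+40', '1267650600228229401496703205376']
print(pd.api.types.is_integer_dtype(df["small"]))  # True - column left untouched
df.to_parquet("cache.parquet")                     # now serializes without OverflowError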
65 changes: 64 additions & 1 deletion tests/unit/test_sql_caching.py
@@ -1,11 +1,16 @@
 import unittest
 from unittest import mock
+from unittest.mock import patch
 
 import pandas as pd
 from parameterized import parameterized
 from pyarrow import ArrowInvalid
 
-from deepnote_toolkit.sql.sql_caching import _generate_cache_key, get_sql_cache
+from deepnote_toolkit.sql.sql_caching import (
+    _generate_cache_key,
+    get_sql_cache,
+    upload_sql_cache,
+)
 from deepnote_toolkit.sql.sql_utils import is_single_select_query
 

@@ -331,3 +336,61 @@ def test_read_from_cache_error_doesnt_raise(
 
         self.assertIsNone(result_df)
         self.assertIsNone(upload_url)
+
+
+class TestUploadSqlCache(unittest.TestCase):
+    @patch("deepnote_toolkit.sql.sql_caching.requests.put")
+    def test_upload_parquet_success(self, mock_put):
+        mock_put.return_value = mock.Mock(raise_for_status=mock.Mock())
+        df = pd.DataFrame({"a": [1, 2, 3]})
+
+        upload_sql_cache(df, "https://example.com/upload")
+
+        mock_put.assert_called_once()
+        args, _ = mock_put.call_args
+        self.assertEqual(args[0], "https://example.com/upload")
+
+    @patch("deepnote_toolkit.sql.sql_caching.requests.put")
+    def test_overflow_error_falls_back_to_pickle(self, mock_put):
+        """Large Python int triggers OverflowError in to_parquet, upload succeeds via pickle."""
+        uploaded_bytes = None
+
+        def capture_put(_url, data):
+            nonlocal uploaded_bytes
+            uploaded_bytes = data.read()
+            return mock.Mock(raise_for_status=mock.Mock())
+
+        mock_put.side_effect = capture_put
+        df = pd.DataFrame({"x": pd.array([2**100, 1], dtype=object)})
+
+        upload_sql_cache(df, "https://example.com/upload")
+
+        roundtripped = pd.read_pickle(pd.io.common.BytesIO(uploaded_bytes))
+        pd.testing.assert_frame_equal(roundtripped, df)
+
+    @patch("deepnote_toolkit.sql.sql_caching.requests.put")
+    def test_pickle_fallback_truncates_partial_parquet_bytes(self, mock_put):
+        """When to_parquet writes partial bytes before failing, truncate clears them."""
+        mock_put.return_value = mock.Mock(raise_for_status=mock.Mock())
+
+        def write_garbage_then_overflow(f, **_kwargs):
+            f.write(b"partial parquet data")
+            raise OverflowError("Python int too large")
+
+        pickle_pos = None
+        pickle_size = None
+
+        def capture_file_state(f, **_kwargs):
+            nonlocal pickle_pos, pickle_size
+            pickle_pos = f.tell()
+            pickle_size = f.seek(0, 2)
+            f.seek(0)
+
+        df = mock.Mock()
+        df.to_parquet.side_effect = write_garbage_then_overflow
+        df.to_pickle.side_effect = capture_file_state
+
+        upload_sql_cache(df, "https://example.com/upload")
+
+        self.assertEqual(pickle_pos, 0, "file should be at position 0")
+        self.assertEqual(pickle_size, 0, "file should be empty after truncate")
61 changes: 61 additions & 0 deletions tests/unit/test_sql_execution_internal.py
@@ -200,6 +200,67 @@ def test_sanitize_dataframe_for_parquet_conversions():
     assert pd.api.types.is_integer_dtype(data["i"]) is True
 
 
+def test_sanitize_dataframe_for_parquet_decimal_large_numbers():
+    """Large decimal.Decimal values must be converted to strings."""
+    from decimal import Decimal
+
+    data = pd.DataFrame(
+        {
+            "d": [Decimal("99999999999999999999999999999999"), Decimal("1.5")],
+            "i": [1, 2],
+        }
+    )
+    se._sanitize_dataframe_for_parquet(data)
+    assert data["d"].dtype == object
+    assert data["d"].iloc[0] == "99999999999999999999999999999999"
+    assert data["d"].iloc[1] == "1.5"
+    assert pd.api.types.is_integer_dtype(data["i"]) is True
+
+
+def test_sanitize_dataframe_for_parquet_decimal_small_numbers():
+    """Decimal values within int64 range should not be converted."""
+    from decimal import Decimal
+
+    data = pd.DataFrame(
+        {
+            "d": [Decimal("100"), Decimal("200")],
+        }
+    )
+    se._sanitize_dataframe_for_parquet(data)
+    assert data["d"].iloc[0] == Decimal("100")
+
+
+def test_sanitize_dataframe_for_parquet_decimal_nan():
+    """Decimal('NaN') must not crash the sanitizer."""
+    from decimal import Decimal
+
+    data = pd.DataFrame(
+        {
+            "d": [Decimal("NaN"), Decimal("42")],
+        }
+    )
+    se._sanitize_dataframe_for_parquet(data)
+    assert data["d"].iloc[1] == Decimal("42")
+
+
+def test_is_large_number():
+    from decimal import Decimal
+
+    assert se._is_large_number(2**63) is True
+    assert se._is_large_number(-(2**63) - 1) is True
+    assert se._is_large_number(2**63 - 1) is False
+    assert se._is_large_number(42) is False
+    assert se._is_large_number(float("inf")) is True
+    assert se._is_large_number(float("nan")) is False
+    assert se._is_large_number(Decimal("1e40")) is True
+    assert se._is_large_number(Decimal("100")) is False
+    assert se._is_large_number(Decimal("NaN")) is False
+    assert se._is_large_number(Decimal("sNaN")) is False
+    assert se._is_large_number(Decimal("Infinity")) is True
+    assert se._is_large_number("not a number") is False
+    assert se._is_large_number(None) is False
+
+
 def test_create_sql_ssh_uri_no_ssh():
     with se._create_sql_ssh_uri(False, {}) as url:
         assert url is None