Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def quote_identifier(identifier: str) -> str:
return '"{}"'.format(identifier.replace('"', '""'))


def _row_to_dict(keys: Sequence[str], row: Sequence[Any]) -> Dict[str, Any]:
"""
Convert a row plus column names to a dictionary.

Duplicate column names are suffixed with ``_2``, ``_3``... so values are
preserved instead of overwritten.
"""
counts: Dict[str, int] = {}
result: Dict[str, Any] = {}
for key, value in zip(keys, row):
count = counts.get(key, 0) + 1
counts[key] = count
final_key = key if count == 1 else "{}_{}".format(key, count)
result[final_key] = value
return result


try:
import pandas as pd # type: ignore
except ImportError:
Expand Down Expand Up @@ -548,7 +565,7 @@ def query(
cursor = self.execute(sql, params or tuple())
keys = [d[0] for d in cursor.description]
for row in cursor:
yield dict(zip(keys, row))
yield _row_to_dict(keys, row)

def execute(
self, sql: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None
Expand Down Expand Up @@ -1445,7 +1462,7 @@ def rows_where(
cursor = self.db.execute(sql, where_args or [])
columns = [c[0] for c in cursor.description]
for row in cursor:
yield dict(zip(columns, row))
yield _row_to_dict(columns, row)

def pks_and_rows_where(
self,
Expand Down Expand Up @@ -2862,7 +2879,7 @@ def search(
)
columns = [c[0] for c in cursor.description]
for row in cursor:
yield dict(zip(columns, row))
yield _row_to_dict(columns, row)

def value_or_default(self, key: str, value: Any) -> Any:
return self._defaults[key] if value is DEFAULT else value
Expand Down
15 changes: 15 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,18 @@ def test_execute_returning_dicts(fresh_db):
assert fresh_db.execute_returning_dicts("select * from test") == [
{"id": 1, "bar": 2}
]


def test_query_duplicate_output_columns_are_suffixed(fresh_db):
fresh_db.execute("create table one (id integer, value text)")
fresh_db.execute("create table two (id integer, value text)")
fresh_db["one"].insert({"id": 1, "value": "left"})
fresh_db["two"].insert({"id": 2, "value": "right"})

rows = list(
fresh_db.query(
"select one.id, two.id, one.value, two.value from one, two where one.id = 1"
)
)

assert rows == [{"id": 1, "id_2": 2, "value": "left", "value_2": "right"}]