Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions examples/spanish/workflow_hitl_checkpoint_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import asyncio
import os
import pickle # noqa: S403
from dataclasses import dataclass
from typing import Any

import psycopg
from psycopg.types.json import Jsonb
from agent_framework import (
Agent,
AgentExecutor,
Expand All @@ -29,10 +29,6 @@
response_handler,
)
from agent_framework import WorkflowCheckpoint

# Importación privada — aún no hay API pública para la codificación de checkpoints.
# Ver: https://github.com/microsoft/agent-framework/issues/4428
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from agent_framework.exceptions import WorkflowCheckpointException
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
Expand All @@ -51,7 +47,10 @@ class PostgresCheckpointStorage:
"""Almacenamiento de checkpoints respaldado por PostgreSQL.

Guarda checkpoints en una sola tabla con columnas para ID, nombre del workflow,
timestamp y los datos JSON codificados. SQL maneja la indexación y el filtrado.
timestamp y los datos serializados con pickle. SQL maneja la indexación y el filtrado.

ADVERTENCIA DE SEGURIDAD: Los checkpoints usan pickle para la serialización.
Solo carga checkpoints de fuentes confiables.
"""

def __init__(self, conninfo: str) -> None:
Expand All @@ -65,7 +64,7 @@ def _ensure_table(self) -> None:
id TEXT PRIMARY KEY,
workflow_name TEXT NOT NULL,
timestamp TEXT NOT NULL,
data JSONB NOT NULL
data BYTEA NOT NULL
)
""")
conn.execute("""
Expand All @@ -75,14 +74,14 @@ def _ensure_table(self) -> None:

async def save(self, checkpoint: WorkflowCheckpoint) -> str:
"""Guarda un checkpoint en PostgreSQL."""
encoded = encode_checkpoint_value(checkpoint.to_dict())
data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301
async with await psycopg.AsyncConnection.connect(self._conninfo) as conn:
await conn.execute(
"""INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data)
VALUES (%s, %s, %s, %s)
ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""",
(checkpoint.checkpoint_id, checkpoint.workflow_name,
checkpoint.timestamp, Jsonb(encoded)),
checkpoint.timestamp, data),
)
return checkpoint.checkpoint_id

Expand All @@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint:
)).fetchone()
if row is None:
raise WorkflowCheckpointException(f"No se encontró checkpoint con ID {checkpoint_id}")
decoded = decode_checkpoint_value(row["data"])
return WorkflowCheckpoint.from_dict(decoded)
return pickle.loads(row["data"]) # noqa: S301

async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
"""Lista todos los checkpoints de un workflow, ordenados por timestamp."""
Expand All @@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi
"SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp",
(workflow_name,),
)).fetchall()
return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows]
return [pickle.loads(r["data"]) for r in rows] # noqa: S301

async def delete(self, checkpoint_id: str) -> bool:
"""Elimina un checkpoint por ID."""
Expand All @@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
)).fetchone()
if row is None:
return None
return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"]))
return pickle.loads(row["data"]) # noqa: S301

async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]:
"""Lista los IDs de checkpoints de un workflow."""
Expand Down
24 changes: 11 additions & 13 deletions examples/workflow_hitl_checkpoint_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import asyncio
import os
import pickle # noqa: S403
from dataclasses import dataclass
from typing import Any

import psycopg
from psycopg.types.json import Jsonb
from agent_framework import (
Agent,
AgentExecutor,
Expand All @@ -29,10 +29,6 @@
response_handler,
)
from agent_framework import WorkflowCheckpoint

# Private import — no public API for checkpoint encoding yet.
# See: https://github.com/microsoft/agent-framework/issues/4428
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from agent_framework.exceptions import WorkflowCheckpointException
from agent_framework.openai import OpenAIChatClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
Expand All @@ -51,7 +47,10 @@ class PostgresCheckpointStorage:
"""PostgreSQL-backed checkpoint storage.

Stores checkpoints in a single table with columns for ID, workflow name,
timestamp, and the encoded JSON data. SQL handles indexing and filtering.
timestamp, and the pickled checkpoint data. SQL handles indexing and filtering.

SECURITY WARNING: Checkpoints use pickle for serialization. Only load
checkpoints from trusted sources.
"""

def __init__(self, conninfo: str) -> None:
Expand All @@ -65,7 +64,7 @@ def _ensure_table(self) -> None:
id TEXT PRIMARY KEY,
workflow_name TEXT NOT NULL,
timestamp TEXT NOT NULL,
data JSONB NOT NULL
data BYTEA NOT NULL
)
""")
conn.execute("""
Expand All @@ -75,14 +74,14 @@ def _ensure_table(self) -> None:

async def save(self, checkpoint: WorkflowCheckpoint) -> str:
"""Save a checkpoint to PostgreSQL."""
encoded = encode_checkpoint_value(checkpoint.to_dict())
data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301
async with await psycopg.AsyncConnection.connect(self._conninfo) as conn:
await conn.execute(
"""INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data)
VALUES (%s, %s, %s, %s)
ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""",
(checkpoint.checkpoint_id, checkpoint.workflow_name,
checkpoint.timestamp, Jsonb(encoded)),
checkpoint.timestamp, data),
)
return checkpoint.checkpoint_id

Expand All @@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint:
)).fetchone()
if row is None:
raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")
decoded = decode_checkpoint_value(row["data"])
return WorkflowCheckpoint.from_dict(decoded)
return pickle.loads(row["data"]) # noqa: S301

async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
"""List all checkpoints for a workflow, ordered by timestamp."""
Expand All @@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi
"SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp",
(workflow_name,),
)).fetchall()
return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows]
return [pickle.loads(r["data"]) for r in rows] # noqa: S301

async def delete(self, checkpoint_id: str) -> bool:
"""Delete a checkpoint by ID."""
Expand All @@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
)).fetchone()
if row is None:
return None
return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"]))
return pickle.loads(row["data"]) # noqa: S301

async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]:
"""List checkpoint IDs for a workflow."""
Expand Down