diff --git a/examples/spanish/workflow_hitl_checkpoint_pg.py b/examples/spanish/workflow_hitl_checkpoint_pg.py index 62480c1..9cfa72f 100644 --- a/examples/spanish/workflow_hitl_checkpoint_pg.py +++ b/examples/spanish/workflow_hitl_checkpoint_pg.py @@ -10,11 +10,11 @@ import asyncio import os +import pickle # noqa: S403 from dataclasses import dataclass from typing import Any import psycopg -from psycopg.types.json import Jsonb from agent_framework import ( Agent, AgentExecutor, @@ -29,10 +29,6 @@ response_handler, ) from agent_framework import WorkflowCheckpoint - -# Importación privada — aún no hay API pública para la codificación de checkpoints. -# Ver: https://github.com/microsoft/agent-framework/issues/4428 -from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from agent_framework.exceptions import WorkflowCheckpointException from agent_framework.openai import OpenAIChatClient from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider @@ -51,7 +47,10 @@ class PostgresCheckpointStorage: """Almacenamiento de checkpoints respaldado por PostgreSQL. Guarda checkpoints en una sola tabla con columnas para ID, nombre del workflow, - timestamp y los datos JSON codificados. SQL maneja la indexación y el filtrado. + timestamp y los datos serializados con pickle. SQL maneja la indexación y el filtrado. + + ADVERTENCIA DE SEGURIDAD: Los checkpoints usan pickle para la serialización. + Solo carga checkpoints de fuentes confiables. """ def __init__(self, conninfo: str) -> None: @@ -65,7 +64,7 @@ def _ensure_table(self) -> None: id TEXT PRIMARY KEY, workflow_name TEXT NOT NULL, timestamp TEXT NOT NULL, - data JSONB NOT NULL + data BYTEA NOT NULL ) """) conn.execute(""" @@ -75,14 +74,14 @@ def _ensure_table(self) -> None: async def save(self, checkpoint: WorkflowCheckpoint) -> str: """Guarda un checkpoint en PostgreSQL.""" - encoded = encode_checkpoint_value(checkpoint.to_dict()) + data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301 async with await psycopg.AsyncConnection.connect(self._conninfo) as conn: await conn.execute( """INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data) VALUES (%s, %s, %s, %s) ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""", (checkpoint.checkpoint_id, checkpoint.workflow_name, - checkpoint.timestamp, Jsonb(encoded)), + checkpoint.timestamp, data), ) return checkpoint.checkpoint_id @@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: )).fetchone() if row is None: raise WorkflowCheckpointException(f"No se encontró checkpoint con ID {checkpoint_id}") - decoded = decode_checkpoint_value(row["data"]) - return WorkflowCheckpoint.from_dict(decoded) + return pickle.loads(row["data"]) # noqa: S301 async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: """Lista todos los checkpoints de un workflow, ordenados por timestamp.""" @@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi "SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp", (workflow_name,), )).fetchall() - return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows] + return [pickle.loads(r["data"]) for r in rows] # noqa: S301 async def delete(self, checkpoint_id: str) -> bool: """Elimina un checkpoint por ID.""" @@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: )).fetchone() if row is None: return None - return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"])) + return pickle.loads(row["data"]) # noqa: S301 async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]: """Lista los IDs de checkpoints de un workflow.""" diff --git a/examples/workflow_hitl_checkpoint_pg.py b/examples/workflow_hitl_checkpoint_pg.py index a13f380..6dc792d 100644 --- a/examples/workflow_hitl_checkpoint_pg.py +++ b/examples/workflow_hitl_checkpoint_pg.py @@ -10,11 +10,11 @@ import asyncio import os +import pickle # noqa: S403 from dataclasses import dataclass from typing import Any import psycopg -from psycopg.types.json import Jsonb from agent_framework import ( Agent, AgentExecutor, @@ -29,10 +29,6 @@ response_handler, ) from agent_framework import WorkflowCheckpoint - -# Private import — no public API for checkpoint encoding yet. -# See: https://github.com/microsoft/agent-framework/issues/4428 -from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from agent_framework.exceptions import WorkflowCheckpointException from agent_framework.openai import OpenAIChatClient from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider @@ -51,7 +47,10 @@ class PostgresCheckpointStorage: """PostgreSQL-backed checkpoint storage. Stores checkpoints in a single table with columns for ID, workflow name, - timestamp, and the encoded JSON data. SQL handles indexing and filtering. + timestamp, and the pickled checkpoint data. SQL handles indexing and filtering. + + SECURITY WARNING: Checkpoints use pickle for serialization. Only load + checkpoints from trusted sources. """ def __init__(self, conninfo: str) -> None: @@ -65,7 +64,7 @@ def _ensure_table(self) -> None: id TEXT PRIMARY KEY, workflow_name TEXT NOT NULL, timestamp TEXT NOT NULL, - data JSONB NOT NULL + data BYTEA NOT NULL ) """) conn.execute(""" @@ -75,14 +74,14 @@ def _ensure_table(self) -> None: async def save(self, checkpoint: WorkflowCheckpoint) -> str: """Save a checkpoint to PostgreSQL.""" - encoded = encode_checkpoint_value(checkpoint.to_dict()) + data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301 async with await psycopg.AsyncConnection.connect(self._conninfo) as conn: await conn.execute( """INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data) VALUES (%s, %s, %s, %s) ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""", (checkpoint.checkpoint_id, checkpoint.workflow_name, - checkpoint.timestamp, Jsonb(encoded)), + checkpoint.timestamp, data), ) return checkpoint.checkpoint_id @@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint: )).fetchone() if row is None: raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}") - decoded = decode_checkpoint_value(row["data"]) - return WorkflowCheckpoint.from_dict(decoded) + return pickle.loads(row["data"]) # noqa: S301 async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: """List all checkpoints for a workflow, ordered by timestamp.""" @@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi "SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp", (workflow_name,), )).fetchall() - return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows] + return [pickle.loads(r["data"]) for r in rows] # noqa: S301 async def delete(self, checkpoint_id: str) -> bool: """Delete a checkpoint by ID.""" @@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: )).fetchone() if row is None: return None - return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"])) + return pickle.loads(row["data"]) # noqa: S301 async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]: """List checkpoint IDs for a workflow."""