Conversation
📝 Walkthrough

This pull request introduces a comprehensive refactoring of the model training and attack toolkit to support multiple model types and generalize configuration handling. The changes rename legacy configuration classes to use ClavaDDPM prefixes (e.g., `DiffusionConfig` → `ClavaDDPMDiffusionConfig`), introduce a new `ModelType` enum to support both TABDDPM and CTGAN model workflows, and add type-specific training result dataclasses (`TabDDPMTrainingResult`, `CTGANTrainingResult`). Configuration keys are standardized from the hardcoded `"tabddpm_training_config_path"` to a generic `"training_config_path"`. New ensemble attack scripts are added for CTGAN workflows, data handling is generalized to drop all ID columns, and type hints are updated across the codebase to reflect the new structures. Multiple example scripts and test files are updated to use the refactored naming conventions and new branching logic.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 13
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
196-224: ⚠️ Potential issue | 🔴 Critical

Persist full `TrainingResult` objects instead of only `synthetic_data`.

Line 223 and line 350 currently store just a `pd.DataFrame`, which discards model/config/save-dir metadata and breaks the new result-object contract used by integration tests.

🐛 Proposed fix

```diff
- attack_data["fine_tuned_results"].append(train_result.synthetic_data)
+ attack_data["fine_tuned_results"].append(train_result)
  ...
- attack_data["trained_results"].append(train_result.synthetic_data)
+ attack_data["trained_results"].append(train_result)
```

Also applies to: 325-350
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py` around lines 196 - 224, The code currently appends only train_result.synthetic_data to attack_data["fine_tuned_results"], losing the TrainingResult metadata; update the two places where you append (the block using train_result from fine_tune_tabddpm_and_synthesize and the block using train_result from train_or_fine_tune_ctgan) to append the entire TrainingResult object (train_result) instead of train_result.synthetic_data, and adjust any downstream consumers/tests to expect TrainingResult entries (preserving model/config/save_dir fields from initial_model_training_results.models and the called functions fine_tune_tabddpm_and_synthesize / train_or_fine_tune_ctgan).

src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
191-200: ⚠️ Potential issue | 🟡 Minor

Return type annotation mismatch.

The function `fine_tune_tabddpm_and_synthesize` has return type `TrainingResult` (line 200), but it actually returns a `TabDDPMTrainingResult` (line 249). This should be consistent with `train_tabddpm_and_synthesize`, which correctly declares `TabDDPMTrainingResult` as its return type.

🛠️ Proposed fix

```diff
 def fine_tune_tabddpm_and_synthesize(
     trained_models: dict[Relation, ClavaDDPMModelArtifacts],
     fine_tune_set: pd.DataFrame,
     configs: ClavaDDPMTrainingConfig,
     save_dir: Path,
     fine_tuning_diffusion_iterations: int = 100,
     fine_tuning_classifier_iterations: int = 10,
     synthesize: bool = True,
     number_of_points_to_synthesize: int = 20000,
-) -> TrainingResult:
+) -> TabDDPMTrainingResult:
```
Verify each finding against the current code and only fix it if needed. In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py` around lines 191 - 200, The return type annotation for fine_tune_tabddpm_and_synthesize is incorrect: it declares TrainingResult but actually returns a TabDDPMTrainingResult; update the function signature of fine_tune_tabddpm_and_synthesize to return TabDDPMTrainingResult (matching train_tabddpm_and_synthesize) and ensure any imports or type references for TabDDPMTrainingResult are present so the annotation resolves.
🧹 Nitpick comments (7)
examples/gan/synthesize.py (1)

37-42: Consider more robust config access for `data_path`.

The current check `config.training.data_path is not None` doesn't handle:

- A missing `data_path` key in the config (would raise `ConfigAttributeError` before the comparison).
- An empty string `""` (would pass the check, resulting in a `_synthetic.csv` filename).

Using a truthy check with safe attribute access would be more defensive:

♻️ Proposed fix for robustness

```diff
-    if config.training.data_path is not None:
-        dataset_name = Path(config.training.data_path).stem
-    else:
-        dataset_name = get_table_name(config.base_data_dir)
+    data_path = getattr(config.training, "data_path", None)
+    if data_path:
+        dataset_name = Path(data_path).stem
+    else:
+        dataset_name = get_table_name(config.base_data_dir)
```
Verify each finding against the current code and only fix it if needed. In `@examples/gan/synthesize.py` around lines 37 - 42, The code that sets dataset_name uses a fragile check (config.training.data_path is not None) which can raise if data_path is missing or produce wrong names for empty strings; update the logic to safely access and truthily validate data_path (e.g., use getattr(config.training, "data_path", None) or vars-accessor and check if isinstance(data_path, str) and data_path.strip() != "") and only then set dataset_name = Path(data_path).stem; otherwise call get_table_name(config.base_data_dir). Ensure the new condition covers missing attribute and empty/whitespace strings before constructing synthetic_data_file.

examples/gan/ensemble_attack/compute_attack_success.py (1)
21-24: Prefer config-driven `target_ids` over hardcoded `[0]`.

Line 23 hardcodes a sentinel ID. Even if ignored today, this can make run provenance unclear and brittle for future multi-target support.
Suggested refactor
```diff
-    compute_attack_success_for_given_targets(
+    target_ids_cfg = config.ensemble_attack.target_model.get("target_ids")
+    target_ids = list(target_ids_cfg) if target_ids_cfg else [0]
+
+    compute_attack_success_for_given_targets(
         target_model_config=config.ensemble_attack.target_model,
-        # TODO: refactor this to work better outside of the challenge context (i.e. no target ID)
-        # No target ID needed for CTGAN, but it needs at least one element in this array. The value does not matter.
-        target_ids=[0],
+        target_ids=target_ids,
         experiment_directory=Path(config.results_dir),
         metaclassifier_model_name=config.ensemble_attack.metaclassifier.meta_classifier_model_name,
     )
```
Verify each finding against the current code and only fix it if needed. In `@examples/gan/ensemble_attack/compute_attack_success.py` around lines 21 - 24, Replace the hardcoded sentinel target_ids=[0] with a config-driven value: read target_ids from the existing config object (e.g., config.target_ids or config.get("target_ids")) and pass that into the call, validating that it is a non-empty list and falling back to a single-element list (e.g., [0]) only if the config provides nothing valid; update the code around the target_ids parameter in compute_attack_success.py where target_ids is passed and keep experiment_directory=Path(config.results_dir) unchanged.

examples/gan/ensemble_attack/config.yaml (1)
29-33: TODO noted: Pipeline flags need proper testing.

The TODO comment on line 30 indicates that the pipeline control flags (`run_data_processing`, `run_shadow_model_training`, `run_metaclassifier_training`) haven't been fully tested. Ensure these are validated before merging to main, or track this in the follow-up work mentioned in the PR description.

Would you like me to open an issue to track proper testing of these pipeline flags?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gan/ensemble_attack/config.yaml` around lines 29 - 33, The TODO notes that the pipeline control flags under the pipeline config (run_data_processing, run_shadow_model_training, run_metaclassifier_training) are untested; add validation and tests: implement runtime validation in the pipeline bootstrap (check the pipeline config object) to assert supported flag combinations and emit clear warnings/errors, add unit/integration tests that exercise each flag and common combinations (e.g., data only, shadow only, metaclassifier only, and all false/true), update config.yaml defaults if needed, and add a CI job or test matrix to run these scenarios so the flags are covered before merging.

examples/gan/ensemble_attack/train_attack_model.py (1)
101-109: Segmentation fault workaround warrants investigation.The dynamic import to avoid a segmentation fault is a red flag. This could indicate memory corruption, circular imports, or incompatible library interactions. The TODO should be prioritized to understand the root cause, as segfaults can be symptoms of deeper issues.
Would you like me to open an issue to track investigating this segmentation fault?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gan/ensemble_attack/train_attack_model.py` around lines 101 - 109, The dynamic import of "examples.ensemble_attack.run_metaclassifier_training" (meta_pipeline) to avoid a segmentation fault is a workaround, not a fix; investigate the root cause by reproducing the segfault with a minimal script that imports the module at top-level and running it under a native debugger (gdb) or memory checker (valgrind) to capture the crash stack, inspect for circular imports between train_attack_model.py and examples.ensemble_attack.run_metaclassifier_training, and check for problematic native extensions or incompatible package versions; after identifying the culprit, either refactor to remove the circular dependency (or fix the native extension/version), add a failing test that imports the module normally, and replace the dynamic import/meta_pipeline.run_metaclassifier_training usage with the normal top-level import once fixed, and open a tracked issue documenting reproduction steps, stack trace, and environment (Python version, OS, dependent library versions).

examples/gan/ensemble_attack/test_attack_model.py (1)
10-12: Avoid pytest-style naming for an executable Hydra entrypoint.

Using a `test_*` filename and `test_*` function for this script triggers pytest collection/execution. The repository has no `testpaths` or `norecursedirs` configuration to exclude the `examples/` directory, so pytest will discover and attempt to run this file by default.

♻️ Suggested change

```diff
 @hydra.main(config_path="./", config_name="config", version_base=None)
-def test_attack_model(config: DictConfig) -> None:
+def main(config: DictConfig) -> None:
     """Main function to test the attack model."""
     log(
         INFO,
         f"Testing attack model against synthetic data at {config.ensemble_attack.target_model.target_synthetic_data_path}...",
     )
     run_metaclassifier_testing(config.ensemble_attack)


 if __name__ == "__main__":
-    test_attack_model()
+    main()
```

Also applies to: 20-21
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gan/ensemble_attack/test_attack_model.py` around lines 10 - 12, The script defines a Hydra entrypoint named test_attack_model and lives in a file whose name starts with test_, causing pytest to discover and run it; rename the entrypoint function (e.g., to run_attack_model or main_attack_model) and rename the file to not start with test_ (or otherwise avoid a test_ prefix) so pytest won't collect it, and update the `@hydra.main-decorated` function name references (test_attack_model) and any other similar functions on lines ~20-21 to the new non-test_* names to preserve Hydra behavior while preventing pytest collection.

tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
19-19: Rename the test to reflect the generalized API.

Line 19 still uses `test_save_additional_tabddpm_config`, but the test now validates `save_additional_training_config`. A generic test name will be clearer for multi-model support.
Verify each finding against the current code and only fix it if needed. In `@tests/unit/attacks/ensemble/test_shadow_model_utils.py` at line 19, Rename the test function test_save_additional_tabddpm_config to a generic name like test_save_additional_training_config to match the updated API (save_additional_training_config); update the function definition and any references to it in the test module (e.g., fixtures or calls) and ensure the test still invokes save_additional_training_config instead of the old TabDDPM-specific helper so the name and behavior are consistent.

tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
71-82: Prefer `isinstance` over exact `type(...) is ...` assertions.

Line 72 and line 118 (and similarly the DataFrame checks) are stricter than needed and can break with harmless subclassing/wrappers.

💡 Suggested change

```diff
-    assert type(result) is TabDDPMTrainingResult
+    assert isinstance(result, TabDDPMTrainingResult)
     ...
-    assert type(result.synthetic_data) is pd.DataFrame
+    assert isinstance(result.synthetic_data, pd.DataFrame)
```

Also applies to: 117-127
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/attacks/ensemble/test_shadow_model_training.py` around lines 71 - 82, The tests use strict type equality checks (e.g., "type(result) is TabDDPMTrainingResult" and "type(result.synthetic_data) is pd.DataFrame") which can fail for valid subclasses or wrappers; change these to use isinstance checks instead: replace exact type comparisons with isinstance(result, TabDDPMTrainingResult) and isinstance(result.synthetic_data, pd.DataFrame) and keep the existing attribute and length assertions (e.g., for synthetic_data length == 5) to preserve test intent; update occurrences around the shadow model training assertions referencing TabDDPMTrainingResult and synthetic_data.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ensemble_attack/compute_attack_success.py`:
- Around line 79-84: Looping over multiple target IDs currently reuses the same
paths when target_model_config lacks "target_model_id", causing duplicated
scores; fix by setting target_model_config.target_model_id = target_id whenever
the key is missing (not only when present) and then regenerate any dependent
paths (attack_probabilities_result_path and challenge_label_path) from
target_model_config (or include target_id in their filenames) so each target
gets distinct prediction/label files; update the code around the existing
conditional that touches target_model_config and ensure these regenerated paths
are used downstream.
- Around line 46-47: The CSV label-loading assumes a single column but uses
to_numpy().squeeze(), which can silently produce wrong shapes if multiple
columns (e.g., index + label) exist; replace the squeeze approach in the block
that sets test_target from challenge_label_path by reading into a DataFrame,
assert df.shape[1] == 1 (or raise a clear ValueError mentioning
challenge_label_path), and set test_target = df.iloc[:, 0].to_numpy(); apply the
identical change in the same pattern in
examples/ensemble_attack/test_attack_model.py where labels are loaded.
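The validated single-column load suggested above might look like the following sketch; the helper name `load_label_vector` is hypothetical, not part of the toolkit:

```python
from pathlib import Path

import numpy as np
import pandas as pd


def load_label_vector(challenge_label_path: Path) -> np.ndarray:
    """Load a 1D label vector from a single-column CSV, failing loudly otherwise."""
    df = pd.read_csv(challenge_label_path)
    if df.shape[1] != 1:
        raise ValueError(
            f"Expected exactly one label column in {challenge_label_path}, "
            f"found {df.shape[1]}: {df.columns.tolist()}"
        )
    # iloc[:, 0] always yields a Series, so to_numpy() is guaranteed to be 1D,
    # unlike to_numpy().squeeze() on a multi-column frame.
    return df.iloc[:, 0].to_numpy()
```

Unlike `squeeze()`, this fails fast on a CSV accidentally saved with `index=True` instead of silently producing a 2D array.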
In `@examples/ensemble_attack/README.md`:
- Line 8: The README references the wrong config key name for the population
output path; update the documentation in the sentence that describes running
run_attack.py (and mentions pipeline.run_data_processing and
configs/experiment_config.yaml) to use data_paths.population_path instead of
data_paths.population_data so it matches the actual config keys (leave
data_paths.midst_data_path and data_paths.processed_attack_data_path unchanged).
In `@examples/ensemble_attack/real_data_collection.py`:
- Around line 225-226: The function mixes a dynamic id_columns list (and
df_population_no_id) with later hardcoded "trans_id" and "account_id"
references; update the code to consistently derive ID column names from
id_columns (e.g., select the transaction and account IDs by matching suffixes
like endswith("_id") or via a small helper that returns the correct column name)
and replace any direct uses of the string literals "trans_id" and "account_id"
with those derived names so all ID handling (creation, selection, and dropping)
uses the same computed variables (id_columns, df_population_no_id or a new
get_id_column helper).
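A minimal sketch of the derived-ID-column approach described above; `get_id_column` and `drop_id_columns` are hypothetical helper names illustrating the suggestion, not existing toolkit functions:

```python
import pandas as pd


def get_id_column(id_columns: list[str], prefix: str) -> str:
    """Return the ID column matching a prefix (e.g., 'trans' -> 'trans_id')."""
    matches = [column for column in id_columns if column == f"{prefix}_id"]
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one '{prefix}_id' column, found {matches}")
    return matches[0]


def drop_id_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Drop every column ending in '_id', mirroring the generalized ID handling."""
    id_columns = [column for column in df.columns if column.endswith("_id")]
    return df.drop(columns=id_columns)
```

With these, the string literals `"trans_id"` and `"account_id"` can be replaced by `get_id_column(id_columns, "trans")` and `get_id_column(id_columns, "account")`, keeping creation, selection, and dropping consistent.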
In `@examples/ensemble_attack/test_attack_model.py`:
- Around line 346-349: The CSV branch reading challenge labels should validate
the CSV has exactly one column and produce a 1D numpy array; update the elif
handling of challenge_label_path to read into a DataFrame, check df.shape[1] ==
1 and raise ValueError if not, then set test_target = df.iloc[:, 0].to_numpy()
(ensuring a 1D array) so downstream code expecting a 1D label vector won’t
break; reference the variables challenge_label_path and test_target to locate
the change.
- Around line 268-277: The code only checks (processed_attack_data_path /
challenge_data_file_name).exists() before skipping collection but then
unconditionally calls load_dataframe for both challenge_data_file_name and
"master_challenge_train.csv", which can fail if the master CSV is missing;
change the condition to verify both files exist (e.g., check existence of
(processed_attack_data_path / challenge_data_file_name) AND
(processed_attack_data_path / "master_challenge_train.csv") ) before setting
df_challenge_experiment and df_master_train via load_dataframe, otherwise fall
through to perform the data collection path or raise a clear error.
In `@examples/gan/ensemble_attack/make_challenge_dataset.py`:
- Line 27: The current sampling line may raise ValueError when the untrained
pool is smaller than the training set; modify the logic around untrained_data =
real_data[~real_data[id_column].isin(training_data[id_column])].sample(len(training_data))
to first compute pool_size =
len(real_data[~real_data[id_column].isin(training_data[id_column])]) and then:
if pool_size >= len(training_data) keep sampling without replacement, else
either call sample(len(training_data), replace=True) or handle it explicitly
(raise a clearer error or sample only pool_size and log a warning) so sampling
never throws; update references to untrained_data, real_data, id_column, and
training_data accordingly.
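The non-throwing sampling logic above might be sketched like this, assuming a warning-based fallback to sampling with replacement when the pool is too small (`sample_untrained` is a hypothetical name):

```python
import logging

import pandas as pd

logger = logging.getLogger(__name__)


def sample_untrained(
    real_data: pd.DataFrame, training_data: pd.DataFrame, id_column: str
) -> pd.DataFrame:
    """Sample a non-member set the same size as the training set without raising."""
    pool = real_data[~real_data[id_column].isin(training_data[id_column])]
    n = len(training_data)
    if len(pool) >= n:
        # Enough non-members: keep the original without-replacement behaviour.
        return pool.sample(n)
    logger.warning(
        "Untrained pool (%d rows) smaller than training set (%d); sampling with replacement.",
        len(pool),
        n,
    )
    return pool.sample(n, replace=True)
```

Raising a clearer `ValueError` instead of the `replace=True` fallback is an equally valid choice; the point is that `pandas.sample` never throws its opaque "larger than population" error.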
In `@examples/gan/ensemble_attack/README.md`:
- Around line 64-65: Fix the typo in the README note: change the config key
reference from ensemble_attack.shadow_trainig.model_name to
ensemble_attack.shadow_training.model_name so the documentation matches the
actual config key used (update the text in the README where the mistaken key
appears).
- Line 81: Update the broken markdown anchor in
examples/gan/ensemble_attack/README.md: replace the link target
`#2-training-the-target-model` with the actual section anchor that matches the
heading "Generating target synthetic data to be tested" (e.g., use the correct
slugified anchor for that heading), so the [step 2] jump link navigates to the
"Generating target synthetic data to be tested" section.
- Around line 3-4: The relative link [Ensemble Attack](examples/ensemble_attack)
in the README line starting "On this example, we demonstrate how to run the
[Ensemble Attack]" resolves incorrectly; update that link target to the correct
relative path from examples/gan/ensemble_attack/README.md (for example change to
../../ensemble_attack or to an absolute /examples/ensemble_attack) so the
[Ensemble Attack] reference points to the actual docs.
In `@examples/gan/train.py`:
- Around line 27-35: The code sets dataset_name and real_data differently
depending on config.training.data_path but later always reads domain/metadata
from config.base_data_dir, causing domain-file coupling; change the logic so the
metadata path is chosen consistently: when config.training.data_path is
provided, set dataset_name = Path(config.training.data_path).stem and derive
metadata_dir = Path(config.training.data_path).parent, otherwise use table_name
and metadata_dir = Path(config.base_data_dir); then load domain metadata from
metadata_dir / f"{dataset_name}_domain.json" (or the existing metadata filename)
so that real_data and domain metadata come from the same directory.
- Around line 37-42: Before calling real_data.sample, validate
config.training.sample_size: ensure it's an int greater than 0 and <=
len(real_data); if not, raise a ValueError with a clear message referencing
sample_size and the available row count. Update the block that currently uses
config.training.sample_size and real_data.sample to perform this check (use
config.training.sample_size, real_data, and dataset_name to build the error
text) so invalid values fail fast with an informative message rather than
letting pandas.sample raise an opaque exception.
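A fail-fast validation sketch for `sample_size`; `validated_sample` is a hypothetical helper wrapping the check described above:

```python
import pandas as pd


def validated_sample(real_data: pd.DataFrame, sample_size: int, dataset_name: str) -> pd.DataFrame:
    """Validate sample_size before delegating to pandas.DataFrame.sample."""
    if not isinstance(sample_size, int) or sample_size <= 0 or sample_size > len(real_data):
        raise ValueError(
            f"training.sample_size must be an integer in [1, {len(real_data)}] "
            f"for dataset '{dataset_name}', got {sample_size!r}"
        )
    return real_data.sample(sample_size)
```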
In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py`:
- Line 13: The file shadow_model_utils.py imports helper functions from examples
(from examples.gan.utils import get_single_table_svd_metadata, get_table_name)
which breaks module boundaries; move or reimplement those helpers inside the src
package (e.g., create a new util module under src/midst_toolkit/, e.g.,
midst_toolkit.utils or midst_toolkit.helpers) and update shadow_model_utils.py
to import get_single_table_svd_metadata and get_table_name from that new
internal module; ensure the new functions keep identical signatures and behavior
so usages in functions/classes inside shadow_model_utils.py continue to work
without changing callers.
---
Outside diff comments:
In `@src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py`:
- Around line 196-224: The code currently appends only
train_result.synthetic_data to attack_data["fine_tuned_results"], losing the
TrainingResult metadata; update the two places where you append (the block using
train_result from fine_tune_tabddpm_and_synthesize and the block using
train_result from train_or_fine_tune_ctgan) to append the entire TrainingResult
object (train_result) instead of train_result.synthetic_data, and adjust any
downstream consumers/tests to expect TrainingResult entries (preserving
model/config/save_dir fields from initial_model_training_results.models and the
called functions fine_tune_tabddpm_and_synthesize / train_or_fine_tune_ctgan).
In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py`:
- Around line 191-200: The return type annotation for
fine_tune_tabddpm_and_synthesize is incorrect: it declares TrainingResult but
actually returns a TabDDPMTrainingResult; update the function signature of
fine_tune_tabddpm_and_synthesize to return TabDDPMTrainingResult (matching
train_tabddpm_and_synthesize) and ensure any imports or type references for
TabDDPMTrainingResult are present so the annotation resolves.
---
Nitpick comments:
In `@examples/gan/ensemble_attack/compute_attack_success.py`:
- Around line 21-24: Replace the hardcoded sentinel target_ids=[0] with a
config-driven value: read target_ids from the existing config object (e.g.,
config.target_ids or config.get("target_ids")) and pass that into the call,
validating that it is a non-empty list and falling back to a single-element list
(e.g., [0]) only if the config provides nothing valid; update the code around
the target_ids parameter in compute_attack_success.py where target_ids is passed
and keep experiment_directory=Path(config.results_dir) unchanged.
In `@examples/gan/ensemble_attack/config.yaml`:
- Around line 29-33: The TODO notes that the pipeline control flags under the
pipeline config (run_data_processing, run_shadow_model_training,
run_metaclassifier_training) are untested; add validation and tests: implement
runtime validation in the pipeline bootstrap (check the pipeline config object)
to assert supported flag combinations and emit clear warnings/errors, add
unit/integration tests that exercise each flag and common combinations (e.g.,
data only, shadow only, metaclassifier only, and all false/true), update
config.yaml defaults if needed, and add a CI job or test matrix to run these
scenarios so the flags are covered before merging.
In `@examples/gan/ensemble_attack/test_attack_model.py`:
- Around line 10-12: The script defines a Hydra entrypoint named
test_attack_model and lives in a file whose name starts with test_, causing
pytest to discover and run it; rename the entrypoint function (e.g., to
run_attack_model or main_attack_model) and rename the file to not start with
test_ (or otherwise avoid a test_ prefix) so pytest won't collect it, and update
the `@hydra.main-decorated` function name references (test_attack_model) and any
other similar functions on lines ~20-21 to the new non-test_* names to preserve
Hydra behavior while preventing pytest collection.
In `@examples/gan/ensemble_attack/train_attack_model.py`:
- Around line 101-109: The dynamic import of
"examples.ensemble_attack.run_metaclassifier_training" (meta_pipeline) to avoid
a segmentation fault is a workaround, not a fix; investigate the root cause by
reproducing the segfault with a minimal script that imports the module at
top-level and running it under a native debugger (gdb) or memory checker
(valgrind) to capture the crash stack, inspect for circular imports between
train_attack_model.py and examples.ensemble_attack.run_metaclassifier_training,
and check for problematic native extensions or incompatible package versions;
after identifying the culprit, either refactor to remove the circular dependency
(or fix the native extension/version), add a failing test that imports the
module normally, and replace the dynamic
import/meta_pipeline.run_metaclassifier_training usage with the normal top-level
import once fixed, and open a tracked issue documenting reproduction steps,
stack trace, and environment (Python version, OS, dependent library versions).
In `@examples/gan/synthesize.py`:
- Around line 37-42: The code that sets dataset_name uses a fragile check
(config.training.data_path is not None) which can raise if data_path is missing
or produce wrong names for empty strings; update the logic to safely access and
truthily validate data_path (e.g., use getattr(config.training, "data_path",
None) or vars-accessor and check if isinstance(data_path, str) and
data_path.strip() != "") and only then set dataset_name = Path(data_path).stem;
otherwise call get_table_name(config.base_data_dir). Ensure the new condition
covers missing attribute and empty/whitespace strings before constructing
synthetic_data_file.
In `@tests/integration/attacks/ensemble/test_shadow_model_training.py`:
- Around line 71-82: The tests use strict type equality checks (e.g.,
"type(result) is TabDDPMTrainingResult" and "type(result.synthetic_data) is
pd.DataFrame") which can fail for valid subclasses or wrappers; change these to
use isinstance checks instead: replace exact type comparisons with
isinstance(result, TabDDPMTrainingResult) and isinstance(result.synthetic_data,
pd.DataFrame) and keep the existing attribute and length assertions (e.g., for
synthetic_data length == 5) to preserve test intent; update occurrences around
the shadow model training assertions referencing TabDDPMTrainingResult and
synthetic_data.
In `@tests/unit/attacks/ensemble/test_shadow_model_utils.py`:
- Line 19: Rename the test function test_save_additional_tabddpm_config to a
generic name like test_save_additional_training_config to match the updated API
(save_additional_training_config); update the function definition and any
references to it in the test module (e.g., fixtures or calls) and ensure the
test still invokes save_additional_training_config instead of the old
TabDDPM-specific helper so the name and behavior are consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 66ffa678-a318-4cab-a033-9f92592645e9
📒 Files selected for processing (39)
- examples/ensemble_attack/README.md
- examples/ensemble_attack/compute_attack_success.py
- examples/ensemble_attack/configs/experiment_config.yaml
- examples/ensemble_attack/configs/original_attack_config.yaml
- examples/ensemble_attack/real_data_collection.py
- examples/ensemble_attack/run_shadow_model_training.py
- examples/ensemble_attack/test_attack_model.py
- examples/gan/README.md
- examples/gan/ensemble_attack/README.md
- examples/gan/ensemble_attack/compute_attack_success.py
- examples/gan/ensemble_attack/config.yaml
- examples/gan/ensemble_attack/make_challenge_dataset.py
- examples/gan/ensemble_attack/test_attack_model.py
- examples/gan/ensemble_attack/train_attack_model.py
- examples/gan/synthesize.py
- examples/gan/train.py
- examples/synthesizing/multi_table/README.md
- examples/synthesizing/multi_table/run_synthesizing.py
- examples/synthesizing/single_table/README.md
- examples/synthesizing/single_table/run_synthesizing.py
- examples/training/multi_table/README.md
- examples/training/multi_table/run_training.py
- examples/training/single_table/README.md
- examples/training/single_table/run_training.py
- src/midst_toolkit/attacks/ensemble/blending.py
- src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py
- src/midst_toolkit/common/config.py
- src/midst_toolkit/models/clavaddpm/clustering.py
- src/midst_toolkit/models/clavaddpm/enumerations.py
- src/midst_toolkit/models/clavaddpm/synthesizer.py
- src/midst_toolkit/models/clavaddpm/train.py
- tests/integration/attacks/ensemble/configs/shadow_training_config.yaml
- tests/integration/attacks/ensemble/test_shadow_model_training.py
- tests/integration/models/clavaddpm/test_model.py
- tests/integration/models/clavaddpm/test_synthesizer.py
- tests/unit/attacks/ensemble/configs/shadow_training_config.yaml
- tests/unit/attacks/ensemble/test_shadow_model_utils.py
```python
elif challenge_label_path.suffix == ".csv":
    test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze()
```
🧩 Analysis chain
🏁 Script executed:
```shell
# First, let's examine the full function context around lines 46-47
head -80 examples/ensemble_attack/compute_attack_success.py | tail -50
```

Repository: VectorInstitute/midst-toolkit
Length of output: 2560
🏁 Script executed:
```shell
# Search for how challenge labels are created/saved to understand the CSV schema
rg -n "challenge_label|to_csv|\.csv" --type=py -B2 -A2
```

Repository: VectorInstitute/midst-toolkit
Length of output: 50385
🏁 Script executed:
```shell
# Check if there are any tests or examples showing the format of challenge labels
fd -t f -e csv | head -20
```

Repository: VectorInstitute/midst-toolkit
Length of output: 1524
🏁 Script executed:
```shell
# Search for any documentation or usage of challenge_label_path
rg -n "challenge_label_path" --type=py -B3 -A3
```

Repository: VectorInstitute/midst-toolkit
Length of output: 7346
🏁 Script executed:
```shell
# Examine sample challenge_label.csv files
for csv_file in $(find tests -name "challenge_label.csv"); do
  echo "=== $csv_file ==="
  head -5 "$csv_file"
done
```

Repository: VectorInstitute/midst-toolkit
Length of output: 780
🏁 Script executed:
```shell
# Check how downstream code uses test_target in metric computation
rg -A10 "def compute_attack_success_for_given_targets" examples/ensemble_attack/compute_attack_success.py
```

Repository: VectorInstitute/midst-toolkit
Length of output: 618
🏁 Script executed:
```shell
# Check how downstream metrics use test_target - look for accuracy, precision, etc. calls
rg -B5 -A10 "def compute_attack_success\(" examples/ensemble_attack/compute_attack_success.py
```

Repository: VectorInstitute/midst-toolkit
Length of output: 55
🏁 Script executed:
# Search for metric computation functions that use test_target/targets
rg -n "accuracy|precision|roc_auc|metric" examples/ensemble_attack/compute_attack_success.py -A3Repository: VectorInstitute/midst-toolkit
Length of output: 1166
🏁 Script executed:
# Check the full compute_attack_success function to see how targets/predictions are used
sed -n '63,110p' examples/ensemble_attack/compute_attack_success.pyRepository: VectorInstitute/midst-toolkit
Length of output: 2595
🏁 Script executed:
# Look for what happens with tpr_at_fpr - what function computes it from targets and predictions
rg -B5 -A10 "tpr_at_fpr" examples/ensemble_attack/compute_attack_success.pyRepository: VectorInstitute/midst-toolkit
Length of output: 991
🏁 Script executed:
# Check imports and metric computation library
head -20 examples/ensemble_attack/compute_attack_success.pyRepository: VectorInstitute/midst-toolkit
Length of output: 647
🏁 Script executed:
# Check how challenge labels are actually created/saved in the codebase
rg -B5 -A5 "\.save\(.*challenge.*label|to_csv.*challenge" examples/ --type=pyRepository: VectorInstitute/midst-toolkit
Length of output: 1294
🏁 Script executed:
# Check if TprAtFpr.get_tpr_at_fpr expects 1D array and how it handles 2D input
find . -name "*.py" -type f -exec grep -l "class TprAtFpr" {} \; | head -1 | xargs catRepository: VectorInstitute/midst-toolkit
Length of output: 13197
🏁 Script executed:
# Verify: can squeeze() fail to produce 1D on multi-column CSV?
python3 << 'EOF'
import pandas as pd
import numpy as np
import tempfile
import os
# Simulate what happens when someone saves with index=True (pandas default)
df = pd.DataFrame({'label': [0, 1, 0, 1]})
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
fname = f.name
df.to_csv(fname, index=True) # Default pandas behavior
# Now load it back (what the code does)
loaded = pd.read_csv(fname)
print(f"Loaded columns: {loaded.columns.tolist()}")
print(f"Loaded shape: {loaded.shape}")
arr = loaded.to_numpy().squeeze()
print(f"After squeeze() - shape: {arr.shape}, ndim: {arr.ndim}")
print(f"len(arr): {len(arr)}")
print(f"Array:\n{arr}")
# Would metric computation break?
try:
# Simulate what metrics expect (1D binary array)
from sklearn.metrics import roc_auc_score
predictions = np.array([0.2, 0.8, 0.3, 0.9])
roc_auc_score(arr, predictions)
print("\nroc_auc_score succeeded on 2D array (shouldn't happen)")
except Exception as e:
print(f"\nroc_auc_score failed with 2D array: {type(e).__name__}: {e}")
os.unlink(fname)
EOFRepository: VectorInstitute/midst-toolkit
Length of output: 191
🏁 Script executed:
# Verify actual CSV saving patterns to confirm index=False is used
rg "to_csv\(" examples/ensemble_attack/ --type=py -B1 -A1 | head -40Repository: VectorInstitute/midst-toolkit
Length of output: 394
🏁 Script executed:
# Check if there are any existing unit tests for this function
rg -l "load_target_challenge_labels_and_probabilities|compute_attack_success" tests/ --type=pyRepository: VectorInstitute/midst-toolkit
Length of output: 55
🏁 Script executed:
# Also check test_attack_model.py which has identical code
grep -n "\.npy\|\.csv" examples/ensemble_attack/test_attack_model.py | grep challenge_labelRepository: VectorInstitute/midst-toolkit
Length of output: 292
Add defensive column count validation for CSV-loaded labels.
The code loads challenge labels from CSV via pd.read_csv(challenge_label_path).to_numpy().squeeze(), which assumes a single-column file. If a CSV with multiple columns is provided (e.g., with an index column), squeeze() would not reduce it to 1D; the length assertion would still pass (checking only the first dimension), but downstream metric computation using roc_curve() would fail. While current usage saves CSVs with index=False and single columns, the code should be defensive against this edge case.
The proposed fix is sound: validate the column count and explicitly extract the first column via iloc[:, 0].to_numpy() instead of relying on squeeze().
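The failure mode can be reproduced with a small standalone snippet (toy data, not the toolkit's actual label files):

```python
import pandas as pd

# Toy label frame that was accidentally saved with its index column
# (what pandas' to_csv does by default when index=False is forgotten).
labels = pd.DataFrame({"Unnamed: 0": [0, 1, 2, 3], "label": [0, 1, 0, 1]})

arr = labels.to_numpy().squeeze()
print(arr.ndim)  # 2 -- squeeze() cannot collapse a (4, 2) array to 1-D
print(len(arr))  # 4 -- yet a length check on axis 0 still passes

# Defensive loading: validate the column count before extracting the labels.
try:
    if labels.shape[1] != 1:
        raise ValueError(f"Expected exactly 1 label column, got {labels.shape[1]}.")
except ValueError as err:
    print(err)  # Expected exactly 1 label column, got 2.
```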
🛠️ Suggested fix
```diff
-    if challenge_label_path.suffix == ".npy":
+    suffix = challenge_label_path.suffix.lower()
+    if suffix == ".npy":
         test_target = np.load(challenge_label_path).squeeze()
-    elif challenge_label_path.suffix == ".csv":
-        test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze()
+    elif suffix == ".csv":
+        labels_df = pd.read_csv(challenge_label_path)
+        if labels_df.shape[1] != 1:
+            raise ValueError(
+                f"Expected exactly 1 label column in {challenge_label_path}, got {labels_df.shape[1]}."
+            )
+        test_target = labels_df.iloc[:, 0].to_numpy()
     else:
         raise ValueError(f"Unsupported challenge label file type: {challenge_label_path}. Must be .npy or .csv.")
```

Note: The identical pattern in examples/ensemble_attack/test_attack_model.py (lines 344–347) should also be updated to maintain consistency.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ensemble_attack/compute_attack_success.py` around lines 46 - 47, The
CSV label-loading assumes a single column but uses to_numpy().squeeze(), which
can silently produce wrong shapes if multiple columns (e.g., index + label)
exist; replace the squeeze approach in the block that sets test_target from
challenge_label_path by reading into a DataFrame, assert df.shape[1] == 1 (or
raise a clear ValueError mentioning challenge_label_path), and set test_target =
df.iloc[:, 0].to_numpy(); apply the identical change in the same pattern in
examples/ensemble_attack/test_attack_model.py where labels are loaded.
```python
# If there is a target model id in the config, override it with the current target id
if "target_model_id" in target_model_config:
    # Override target model id in config as ``attack_probabilities_result_path`` and
    # ``challenge_label_path`` are dependent on it and change in runtime.
    target_model_config.target_model_id = target_id
```
Guard against repeated scoring when target_model_id is missing.
When multiple target IDs are passed but target_model_id is absent, the loop reuses the same paths for every ID and silently duplicates the same target’s predictions/labels in the aggregate metric.
🛠️ Suggested fix
```diff
 def compute_attack_success_for_given_targets(
@@
 ) -> None:
@@
+    if "target_model_id" not in target_model_config and len(target_ids) > 1:
+        raise ValueError(
+            "target_model.target_model_id is required when evaluating multiple target IDs."
+        )
+
     predictions = []
     targets = []
     for target_id in target_ids:
         # If there is a target model id in the config, override it with the current target id
         if "target_model_id" in target_model_config:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ensemble_attack/compute_attack_success.py` around lines 79 - 84,
Looping over multiple target IDs currently reuses the same paths when
target_model_config lacks "target_model_id", causing duplicated scores; fix by
setting target_model_config.target_model_id = target_id whenever the key is
missing (not only when present) and then regenerate any dependent paths
(attack_probabilities_result_path and challenge_label_path) from
target_model_config (or include target_id in their filenames) so each target
gets distinct prediction/label files; update the code around the existing
conditional that touches target_model_config and ensure these regenerated paths
are used downstream.
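The silent-duplication effect is easy to see with an illustrative stand-in for the per-target loading (the function and path below are hypothetical, not the script's actual API):

```python
import numpy as np

# Illustrative stand-in for per-target predictions loaded from disk; when the
# config lacks target_model_id, every "target" resolves to the same file.
def load_predictions(path: str) -> np.ndarray:
    return np.array([0.9, 0.2, 0.7, 0.4])  # same content regardless of target

target_ids = [1, 2, 3]
predictions = [load_predictions("results/target.npy") for _ in target_ids]
stacked = np.concatenate(predictions)

# The aggregate is just three verbatim copies of one target's scores.
print(stacked.shape)                              # (12,)
print(np.array_equal(stacked[:4], stacked[4:8]))  # True
print(np.array_equal(stacked[:4], stacked[8:]))   # True
```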
```diff
 Make sure directories and JSON files specified in `data_paths` and `data_processing_config` configurations in `examples/ensemble_attack/configs/experiment_config.yaml` exist.

-To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in `config.yaml`. It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_data` and `data_paths.processed_attack_data_path` directories.
+To run the whole data processing pipeline, run `run_attack.py` and set `pipeline.run_data_processing` to `true` in [`configs/experiment_config.yaml`](configs/experiment_config.yaml). It reads data from `data_paths.midst_data_path` specified in config, and populates `data_paths.population_data` and `data_paths.processed_attack_data_path` directories.
```
Fix the config key name in the docs (population_data → population_path).
The key shown here doesn’t match the config and can cause confusion while setting up paths.
📝 Suggested doc fix
```diff
-... and populates `data_paths.population_data` and `data_paths.processed_attack_data_path` directories.
+... and populates `data_paths.population_path` and `data_paths.processed_attack_data_path` directories.
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ensemble_attack/README.md` at line 8, The README references the
wrong config key name for the population output path; update the documentation
in the sentence that describes running run_attack.py (and mentions
pipeline.run_data_processing and configs/experiment_config.yaml) to use
data_paths.population_path instead of data_paths.population_data so it matches
the actual config keys (leave data_paths.midst_data_path and
data_paths.processed_attack_data_path unchanged).
```python
id_columns = [c for c in df_population.columns if c.endswith("_id")]
df_population_no_id = df_population.drop(columns=id_columns)
```
Inconsistent ID column handling within the same function.
Lines 225-226 use a generalized approach to drop all columns ending with _id, but later lines (246, 249, 260) hardcode "trans_id" and "account_id". This inconsistency could cause issues if datasets have different ID column naming conventions.
Consider unifying the approach:
🛠️ Proposed fix to unify ID column handling
```diff
 # Population data without the challenge points
-    df_population_no_challenge = df_population[~df_population["trans_id"].isin(df_challenge["trans_id"])]
+    # Assuming trans_id is the primary key for deduplication
+    primary_id_col = "trans_id"  # Could be parameterized if needed
+    df_population_no_challenge = df_population[~df_population[primary_id_col].isin(df_challenge[primary_id_col])]
     save_dataframe(df_population_no_challenge, save_dir, "population_all_no_challenge.csv")
     # Remove ids
-    df_population_no_challenge_no_id = df_population_no_challenge.drop(columns=["trans_id", "account_id"])
+    id_columns_to_drop = [c for c in df_population_no_challenge.columns if c.endswith("_id")]
+    df_population_no_challenge_no_id = df_population_no_challenge.drop(columns=id_columns_to_drop)
     save_dataframe(
         df_population_no_challenge_no_id,
         save_dir,
         "population_all_no_challenge_no_id.csv",
     )
     # Population data with all the challenge points
     df_population_with_challenge = pd.concat([df_population_no_challenge, df_challenge])
     save_dataframe(df_population_with_challenge, save_dir, "population_all_with_challenge.csv")
     # Remove ids
-    df_population_with_challenge_no_id = df_population_with_challenge.drop(columns=["trans_id", "account_id"])
+    id_columns_to_drop = [c for c in df_population_with_challenge.columns if c.endswith("_id")]
+    df_population_with_challenge_no_id = df_population_with_challenge.drop(columns=id_columns_to_drop)
     save_dataframe(
```

Also applies to: 246-249, 260-260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ensemble_attack/real_data_collection.py` around lines 225 - 226, The
function mixes a dynamic id_columns list (and df_population_no_id) with later
hardcoded "trans_id" and "account_id" references; update the code to
consistently derive ID column names from id_columns (e.g., select the
transaction and account IDs by matching suffixes like endswith("_id") or via a
small helper that returns the correct column name) and replace any direct uses
of the string literals "trans_id" and "account_id" with those derived names so
all ID handling (creation, selection, and dropping) uses the same computed
variables (id_columns, df_population_no_id or a new get_id_column helper).
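The generalized drop can be sketched on a toy frame (the columns below are illustrative, not the real schema):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "trans_id": [1, 2],
        "account_id": [10, 20],
        "client_id": [7, 8],  # an extra ID column a new dataset might carry
        "amount": [100.0, 250.5],
    }
)

# Derive the ID columns once instead of hardcoding "trans_id"/"account_id".
id_columns = [c for c in df.columns if c.endswith("_id")]
df_no_id = df.drop(columns=id_columns)

print(id_columns)              # ['trans_id', 'account_id', 'client_id']
print(list(df_no_id.columns))  # ['amount']
```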
```python
if (processed_attack_data_path / challenge_data_file_name).exists():
    log(INFO, "Skipping data collection for testing phase.")
    df_challenge_experiment = load_dataframe(
        processed_attack_data_path,
        challenge_data_file_name,
    )
    df_master_train = load_dataframe(
        processed_attack_data_path,
        "master_challenge_train.csv",
    )
```
Require both cached CSVs before skipping data collection.
Line [268] checks only population_all_with_challenge_challenge_data.csv, but Lines [274-277] always load master_challenge_train.csv too. In a partial-cache state this fails at runtime.
Suggested fix
```diff
-    if (processed_attack_data_path / challenge_data_file_name).exists():
+    challenge_data_file = processed_attack_data_path / challenge_data_file_name
+    master_train_file = processed_attack_data_path / "master_challenge_train.csv"
+
+    if challenge_data_file.exists() and master_train_file.exists():
         log(INFO, "Skipping data collection for testing phase.")
         df_challenge_experiment = load_dataframe(
             processed_attack_data_path,
             challenge_data_file_name,
         )
         df_master_train = load_dataframe(
             processed_attack_data_path,
             "master_challenge_train.csv",
         )
     else:
```
else:🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ensemble_attack/test_attack_model.py` around lines 268 - 277, The
code only checks (processed_attack_data_path /
challenge_data_file_name).exists() before skipping collection but then
unconditionally calls load_dataframe for both challenge_data_file_name and
"master_challenge_train.csv", which can fail if the master CSV is missing;
change the condition to verify both files exist (e.g., check existence of
(processed_attack_data_path / challenge_data_file_name) AND
(processed_attack_data_path / "master_challenge_train.csv") ) before setting
df_challenge_experiment and df_master_train via load_dataframe, otherwise fall
through to perform the data collection path or raise a clear error.
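The all-or-nothing cache guard can be sketched with pathlib; `cache_is_complete` is a hypothetical helper, and the filenames are the two taken from this comment:

```python
from pathlib import Path
from tempfile import TemporaryDirectory

def cache_is_complete(cache_dir: Path, file_names: list[str]) -> bool:
    """True only when every expected cached CSV is present."""
    return all((cache_dir / name).exists() for name in file_names)

required = ["population_all_with_challenge_challenge_data.csv", "master_challenge_train.csv"]

with TemporaryDirectory() as tmp:
    cache = Path(tmp)
    # Partial cache: only the challenge-data file exists.
    (cache / required[0]).touch()
    print(cache_is_complete(cache, required))  # False -> fall through to collection

    (cache / required[1]).touch()
    print(cache_is_complete(cache, required))  # True -> safe to skip collection
```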
```markdown
> In the [`config.yaml`](config.yaml) file, the attribute `ensemble_attack.shadow_trainig.model_name`
> is what determines this attack will be run with the CTGAN model.
```
Correct the config key typo in Step 4 note.
Line [64] references ensemble_attack.shadow_trainig.model_name; it should be ensemble_attack.shadow_training.model_name.
Suggested fix
```diff
-> In the [`config.yaml`](config.yaml) file, the attribute `ensemble_attack.shadow_trainig.model_name`
+> In the [`config.yaml`](config.yaml) file, the attribute `ensemble_attack.shadow_training.model_name`
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gan/ensemble_attack/README.md` around lines 64 - 65, Fix the typo in
the README note: change the config key reference from
ensemble_attack.shadow_trainig.model_name to
ensemble_attack.shadow_training.model_name so the documentation matches the
actual config key used (update the text in the README where the mistaken key
appears).
```markdown
## 5. Testing the attack model

To test the attack model against the target model and synthetic data produced on
[step 2](#2-training-the-target-model), please run:
```
Update Step 2 anchor to match the actual heading.
Line [81] uses #2-training-the-target-model, but Section 2 is titled “Generating target synthetic data to be tested”, so the jump link is broken.
Suggested fix
```diff
-[step 2](#2-training-the-target-model)
+[step 2](#2-generating-target-synthetic-data-to-be-tested)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gan/ensemble_attack/README.md` at line 81, Update the broken
markdown anchor in examples/gan/ensemble_attack/README.md: replace the link
target `#2-training-the-target-model` with the actual section anchor that
matches the heading "Generating target synthetic data to be tested" (e.g., use
the correct slugified anchor for that heading), so the [step 2] jump link
navigates to the "Generating target synthetic data to be tested" section.
```python
if config.training.data_path is None:
    log(INFO, "Loading data with table name...")
    dataset_name = table_name
    real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")

else:
    log(INFO, f"Loading data from {config.training.data_path}...")
    dataset_name = Path(config.training.data_path).stem
    real_data = pd.read_csv(config.training.data_path)
```
Fix domain-file coupling when training.data_path is provided.
When data is loaded from config.training.data_path, the code still loads domain metadata from config.base_data_dir (Line 43). This can pair the wrong domain with the selected CSV and break non-challenge runs.
Proposed fix
```diff
 table_name = get_table_name(config.base_data_dir)
 if config.training.data_path is None:
     log(INFO, "Loading data with table name...")
     dataset_name = table_name
-    real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")
+    data_path = Path(config.base_data_dir) / f"{table_name}.csv"
+    domain_path = Path(config.base_data_dir) / f"{table_name}_domain.json"
+    real_data = pd.read_csv(data_path)
 else:
     log(INFO, f"Loading data from {config.training.data_path}...")
-    dataset_name = Path(config.training.data_path).stem
-    real_data = pd.read_csv(config.training.data_path)
+    data_path = Path(config.training.data_path)
+    dataset_name = data_path.stem
+    domain_path = data_path.with_name(f"{dataset_name}_domain.json")
+    real_data = pd.read_csv(data_path)
@@
-with open(Path(config.base_data_dir) / f"{table_name}_domain.json", "r") as f:
+with open(domain_path, "r") as f:
     domain_info = json.load(f)
```

Also applies to: 43-44
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gan/train.py` around lines 27 - 35, The code sets dataset_name and
real_data differently depending on config.training.data_path but later always
reads domain/metadata from config.base_data_dir, causing domain-file coupling;
change the logic so the metadata path is chosen consistently: when
config.training.data_path is provided, set dataset_name =
Path(config.training.data_path).stem and derive metadata_dir =
Path(config.training.data_path).parent, otherwise use table_name and
metadata_dir = Path(config.base_data_dir); then load domain metadata from
metadata_dir / f"{dataset_name}_domain.json" (or the existing metadata filename)
so that real_data and domain metadata come from the same directory.
```python
if config.training.sample_size is not None:
    log(INFO, f"Sampling {config.training.sample_size} rows from data...")
    real_data = real_data.sample(n=config.training.sample_size)
    Path(config.results_dir).mkdir(parents=True, exist_ok=True)
    real_data.to_csv(Path(config.results_dir) / f"{dataset_name}_sampled.csv", index=False)
```
Validate sample_size bounds before sampling.
Add an explicit check so invalid values (non-positive or larger than row count) fail with a clear message instead of a runtime sampling exception.
Proposed fix
```diff
 if config.training.sample_size is not None:
     log(INFO, f"Sampling {config.training.sample_size} rows from data...")
+    if config.training.sample_size <= 0 or config.training.sample_size > len(real_data):
+        raise ValueError(
+            f"training.sample_size must be in [1, {len(real_data)}], got {config.training.sample_size}"
+        )
     real_data = real_data.sample(n=config.training.sample_size)
     Path(config.results_dir).mkdir(parents=True, exist_ok=True)
     real_data.to_csv(Path(config.results_dir) / f"{dataset_name}_sampled.csv", index=False)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gan/train.py` around lines 37 - 42, Before calling real_data.sample,
validate config.training.sample_size: ensure it's an int greater than 0 and <=
len(real_data); if not, raise a ValueError with a clear message referencing
sample_size and the available row count. Update the block that currently uses
config.training.sample_size and real_data.sample to perform this check (use
config.training.sample_size, real_data, and dataset_name to build the error
text) so invalid values fail fast with an informative message rather than
letting pandas.sample raise an opaque exception.
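For context, pandas already rejects an oversized `n` (with `replace=False`, its default), but only with a generic error; a pre-check such as the hypothetical `sample_rows` below fails with an actionable message:

```python
import pandas as pd

df = pd.DataFrame({"x": range(5)})

# Without a pre-check, pandas raises its own (less specific) ValueError.
try:
    df.sample(n=10)
except ValueError as err:
    print(type(err).__name__)  # ValueError

# Explicit bounds check with an actionable message.
def sample_rows(data: pd.DataFrame, sample_size: int) -> pd.DataFrame:
    if sample_size <= 0 or sample_size > len(data):
        raise ValueError(f"sample_size must be in [1, {len(data)}], got {sample_size}")
    return data.sample(n=sample_size)

print(len(sample_rows(df, 3)))  # 3
```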
```python
import pandas as pd
from sdv.single_table import CTGANSynthesizer  # type: ignore[import-untyped]

from examples.gan.utils import get_single_table_svd_metadata, get_table_name
```
Importing from examples into src violates module boundaries.
The import from examples.gan.utils import get_single_table_svd_metadata, get_table_name introduces a problematic dependency where library code (src/) depends on example code (examples/). This inverts the expected dependency direction and can cause issues:
- Examples should be standalone and depend on the library, not vice versa.
- This creates fragility if examples are reorganized or removed.
Consider moving get_single_table_svd_metadata and get_table_name to a utility module within src/midst_toolkit/.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py` at line 13, The
file shadow_model_utils.py imports helper functions from examples (from
examples.gan.utils import get_single_table_svd_metadata, get_table_name) which
breaks module boundaries; move or reimplement those helpers inside the src
package (e.g., create a new util module under src/midst_toolkit/, e.g.,
midst_toolkit.utils or midst_toolkit.helpers) and update shadow_model_utils.py
to import get_single_table_svd_metadata and get_table_name from that new
internal module; ensure the new functions keep identical signatures and behavior
so usages in functions/classes inside shadow_model_utils.py continue to work
without changing callers.
```diff
 # Challenge labels are the true membership labels for the challenge points.
-test_target = pd.read_csv(challenge_label_path).to_numpy().squeeze()
+if challenge_label_path.suffix == ".npy":
```
Is there a case where the inference result is saved as a .npy file?
At one point in our code we save the labels as .npy so probably there is a bug here, but I don't want to remove the old code that loads a .csv in case there is some use case for that so I kept both.
I think the labels from the data aggregation from the tartan federer data on challenge points are a .csv but I'm not sure. At one point in our code we save those labels as an intermediary result when training the attack model and those are saved as .npy. I think it doesn't hurt to support both as it's not a lot of work.
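Supporting both formats can be sketched as below; `load_challenge_labels` mirrors the branching discussed above, and the `is_train` column name is purely illustrative:

```python
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd

def load_challenge_labels(path: Path) -> np.ndarray:
    """Load membership labels saved either as .npy or as a one-column .csv."""
    if path.suffix == ".npy":
        return np.load(path).squeeze()
    if path.suffix == ".csv":
        return pd.read_csv(path).to_numpy().squeeze()
    raise ValueError(f"Unsupported challenge label file type: {path}. Must be .npy or .csv.")

with TemporaryDirectory() as tmp:
    labels = np.array([0, 1, 1, 0])
    npy_path = Path(tmp) / "labels.npy"
    csv_path = Path(tmp) / "labels.csv"
    np.save(npy_path, labels)
    pd.DataFrame({"is_train": labels}).to_csv(csv_path, index=False)

    assert np.array_equal(load_challenge_labels(npy_path), labels)
    assert np.array_equal(load_challenge_labels(csv_path), labels)
    print("both formats load identically")
```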
sarakodeiri left a comment
Took a quick look and things look neat in general. Added a few minor comments.
```diff
-# TODO: Add this to config or .json files
-table_name = "trans"
+table_name = config.table_name if "table_name" in config else DEFAULT_TABLE_NAME
```
Great change. Makes it easier for potentially moving the code to the core library.
```markdown
First, we need the data. Download it from this
[Google Drive link](https://drive.google.com/file/d/1B9z4vh51mH6ZMj5E0pJitqR8lid3EJKM/view?usp=drive_link),
extract the files and place them in a `/data/ensemble_attack` folder in within this folder
```
I think using both "in" and "within" is a typo.
```python
    trained_model: CTGANSynthesizer | None = None,
) -> TrainingResult:
    """
    Train or fine tune a CTGAN model on the provided dataset and optionally synthesize data.
```
it would be nice if the function name somehow reflected that data synthesis also happens within this function.
PR Type
Feature
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868h6nkzy
Making the Ensemble Attack run with CTGAN and adding the code required to make it run in an example.
The code is very rough, mostly just some if conditions and minor modifications to make it work with both TabDDPM/ClavaDDPM and CTGAN at the same time.
Also, there were lots of minor changes made in order for the code to work outside of the context of a challenge and make the code a little more flexible for other dataset types.
In follow-up PRs, I will work on moving parts of this code from the `examples` folder to the main lib folder and also make it more extensible to other model types.

Tests Added

No tests have been added.