Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 73 additions & 64 deletions cursorless-talon/src/csv_overrides.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import csv
import typing
from collections import defaultdict
from collections.abc import Container
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Iterable, Optional, TypedDict
from typing import Callable, Container, Iterable, Optional, Sequence, TypedDict

from talon import Context, Module, actions, app, fs, settings

Expand Down Expand Up @@ -49,6 +48,12 @@ class SpokenFormEntry:
spoken_forms: list[str]


class ResultsListEntry(TypedDict):
spoken: str
id: str
list: str


def csv_get_ctx():
return ctx

Expand All @@ -60,17 +65,17 @@ def csv_get_normalized_ctx():
def init_csv_and_watch_changes(
filename: str,
default_values: ListToSpokenForms,
handle_new_values: Optional[Callable[[list[SpokenFormEntry]], None]] = None,
handle_new_values: Optional[Callable[[Sequence[SpokenFormEntry]], None]] = None,
*,
extra_ignored_values: Optional[list[str]] = None,
extra_allowed_values: Optional[list[str]] = None,
extra_ignored_values: Optional[Sequence[str]] = None,
extra_allowed_values: Optional[Sequence[str]] = None,
allow_unknown_values: bool = False,
deprecated: bool = False,
default_list_name: Optional[str] = None,
headers: list[str] = [SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER],
headers: Optional[Sequence[str]] = None,
no_update_file: bool = False,
pluralize_lists: Optional[list[str]] = None,
):
pluralize_lists: Optional[Sequence[str]] = None,
) -> Callable[[], None]:
"""
Initialize a cursorless settings csv, creating it if necessary, and watch
for changes to the csv. Talon lists will be generated based on the keys of
Expand All @@ -91,21 +96,21 @@ def init_csv_and_watch_changes(
`cursorles-settings` dir
default_values (ListToSpokenForms): The default values for the lists to
be customized in the given csv
handle_new_values (Optional[Callable[[list[SpokenFormEntry]], None]]): A
handle_new_values (Optional[Callable[[Sequence[SpokenFormEntry]], None]]): A
callback to be called when the lists are updated
extra_ignored_values (Optional[list[str]]): Don't throw an exception if
extra_ignored_values (Optional[Sequence[str]]): Don't throw an exception if
any of these appear as values; just ignore them and don't add them
to any list
allow_unknown_values (bool): If unknown values appear, just put them in
the list
default_list_name (Optional[str]): If unknown values are
allowed, put any unknown values in this list
headers (list[str]): The headers to use for the csv
headers (Optional[Sequence[str]]): The headers to use for the csv
no_update_file (bool): Set this to `True` to indicate that we should not
update the csv. This is used generally in case there was an issue
coming up with the default set of values so we don't want to persist
those to disk
pluralize_lists (list[str]): Create plural version of given lists
pluralize_lists (Optional[Sequence[str]]): Create plural version of given lists
"""
# Don't allow both `extra_allowed_values` and `allow_unknown_values`
assert not (extra_allowed_values and allow_unknown_values)
Expand All @@ -116,6 +121,8 @@ def init_csv_and_watch_changes(
(extra_allowed_values or allow_unknown_values) and not default_list_name
)

if headers is None:
headers = (SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER)
if extra_ignored_values is None:
extra_ignored_values = []
if extra_allowed_values is None:
Expand All @@ -137,7 +144,7 @@ def init_csv_and_watch_changes(
check_for_duplicates(filename, default_values)
create_default_vocabulary_dicts(default_values, pluralize_lists)

def on_watch(path, flags):
def on_watch(path: str, _flags) -> None:
if file_path.match(path):
current_values, has_errors = read_file(
path=file_path,
Expand Down Expand Up @@ -194,16 +201,16 @@ def on_watch(path, flags):
handle_new_values=handle_new_values,
)

def unsubscribe():
def unsubscribe() -> None:
fs.unwatch(file_path.parent, on_watch)

return unsubscribe


def check_for_duplicates(filename, default_values):
def check_for_duplicates(filename: str, default_values: ListToSpokenForms):
results_map = {}
for list_name, dict in default_values.items():
for key, value in dict.items():
for list_name, values in default_values.items():
for key, value in values.items():
if value in results_map:
existing_list_name = results_map[value]
warning = f"WARNING ({filename}): Value `{value}` duplicated between lists '{existing_list_name}' and '{list_name}'"
Expand All @@ -213,16 +220,17 @@ def check_for_duplicates(filename, default_values):
results_map[value] = list_name


def is_removed(value: str):
def is_removed(value: str) -> bool:
return value.startswith("-")


def create_default_vocabulary_dicts(
default_values: dict[str, dict], pluralize_lists: list[str]
default_values: ListToSpokenForms,
pluralize_lists: Sequence[str],
):
default_values_updated = {}
for key, value in default_values.items():
updated_dict = {}
updated_dict: dict[str, str] = {}
for key2, value2 in value.items():
# Enable deactivated(prefixed with a `-`) items
active_key = key2[1:] if key2.startswith("-") else key2
Expand All @@ -235,17 +243,17 @@ def create_default_vocabulary_dicts(
def update_dicts(
default_values: ListToSpokenForms,
current_values: dict[str, str],
extra_ignored_values: list[str],
extra_allowed_values: list[str],
extra_ignored_values: Sequence[str],
extra_allowed_values: Sequence[str],
allow_unknown_values: bool,
default_list_name: str | None,
pluralize_lists: list[str],
handle_new_values: Callable[[list[SpokenFormEntry]], None] | None,
):
pluralize_lists: Sequence[str],
handle_new_values: Callable[[Sequence[SpokenFormEntry]], None] | None,
) -> None:
# Create map with all default values
results_map: dict[str, ResultsListEntry] = {}
for list_name, obj in default_values.items():
for spoken, id in obj.items():
for list_name, values in default_values.items():
for spoken, id in values.items():
results_map[id] = {"spoken": spoken, "id": id, "list": list_name}

# Update result with current values
Expand Down Expand Up @@ -281,13 +289,9 @@ def update_dicts(
handle_new_values(spoken_form_entries)


class ResultsListEntry(TypedDict):
spoken: str
id: str
list: str


def generate_spoken_forms(results_list: Iterable[ResultsListEntry]):
def generate_spoken_forms(
results_list: Iterable[ResultsListEntry],
) -> Iterable[SpokenFormEntry]:
for obj in results_list:
id = obj["id"]
spoken = obj["spoken"]
Expand Down Expand Up @@ -315,25 +319,25 @@ def generate_spoken_forms(results_list: Iterable[ResultsListEntry]):
def assign_lists_to_context(
ctx: Context,
lists: ListToSpokenForms,
pluralize_lists: list[str],
):
for list_name, dict in lists.items():
pluralize_lists: Sequence[str],
) -> None:
for list_name, values in lists.items():
list_singular_name = get_cursorless_list_name(list_name)
ctx.lists[list_singular_name] = dict
ctx.lists[list_singular_name] = values
if list_name in pluralize_lists:
list_plural_name = f"{list_singular_name}_plural"
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in dict.items()}
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in values.items()}


def update_file(
path: Path,
headers: list[str],
headers: Sequence[str],
default_values: dict[str, str],
extra_ignored_values: list[str],
extra_allowed_values: list[str],
extra_ignored_values: Sequence[str],
extra_allowed_values: Sequence[str],
allow_unknown_values: bool,
no_update_file: bool,
):
) -> dict[str, str]:
current_values, has_errors = read_file(
path=path,
headers=headers,
Expand All @@ -344,7 +348,7 @@ def update_file(
)
current_identifiers = current_values.values()

missing = {}
missing: dict[str, str] = {}
for key, value in default_values.items():
if value not in current_identifiers:
missing[key] = value
Expand All @@ -357,16 +361,17 @@ def update_file(
)
else:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
missing_items = sorted(missing.items())
lines = [
f"# {timestamp} - New entries automatically added by cursorless",
*[create_line(key, missing[key]) for key in sorted(missing)],
*[create_line(key, value) for key, value in missing_items],
]
with open(path, "a") as f:
f.write("\n\n" + "\n".join(lines))

print(f"New cursorless features added to {path.name}")
for key in sorted(missing):
print(f"{key}: {missing[key]}")
for key, value in missing_items:
print(f"{key}: {value}")
print(
"See release notes for more info: "
"https://github.com/cursorless-dev/cursorless/blob/main/CHANGELOG.md"
Expand All @@ -376,18 +381,22 @@ def update_file(
return current_values


def create_line(*cells: str):
return ", ".join(cells)


def create_file(path: Path, headers: list[str], default_values: dict):
lines = [create_line(key, default_values[key]) for key in sorted(default_values)]
def create_file(
path: Path,
headers: Sequence[str],
default_values: dict[str, str],
) -> None:
lines = [create_line(key, value) for key, value in sorted(default_values.items())]
lines.insert(0, create_line(*headers))
lines.append("")
path.write_text("\n".join(lines))


def csv_error(path: Path, index: int, message: str, value: str):
def create_line(*cells: str) -> str:
return ", ".join(cells)


def csv_error(path: Path, index: int, message: str, value: str) -> None:
"""Check that an expected condition is true

Note that we try to continue reading in this case so cursorless doesn't get bricked
Expand All @@ -402,19 +411,19 @@ def csv_error(path: Path, index: int, message: str, value: str):

def read_file(
path: Path,
headers: list[str],
headers: Sequence[str],
default_identifiers: Container[str],
extra_ignored_values: list[str],
extra_allowed_values: list[str],
extra_ignored_values: Sequence[str],
extra_allowed_values: Sequence[str],
allow_unknown_values: bool,
):
) -> tuple[dict[str, str], bool]:
with open(path) as csv_file:
# Use `skipinitialspace` to allow spaces before quote. `, "a,b"`
csv_reader = csv.reader(csv_file, skipinitialspace=True)
rows = list(csv_reader)

result = {}
used_identifiers = []
result: dict[str, str] = {}
used_identifiers: set[str] = set()
has_errors = False
seen_headers = False

Expand All @@ -427,7 +436,7 @@ def read_file(

if not seen_headers:
seen_headers = True
if row != headers:
if row != list(headers):
has_errors = True
csv_error(path, i, "Malformed header", create_line(*row))
print(f"Expected '{create_line(*headers)}'")
Expand Down Expand Up @@ -461,15 +470,15 @@ def read_file(
continue

result[key] = value
used_identifiers.append(value)
used_identifiers.add(value)

if has_errors:
app.notify("Cursorless settings error; see log")

return result, has_errors


def get_full_path(filename: str):
def get_full_path(filename: str) -> Path:
if not filename.endswith(".csv"):
filename = f"{filename}.csv"

Expand All @@ -484,7 +493,7 @@ def get_full_path(filename: str):
return (settings_directory / filename).resolve()


def get_super_values(values: ListToSpokenForms):
def get_super_values(values: ListToSpokenForms) -> dict[str, str]:
result: dict[str, str] = {}
for value_dict in values.values():
result.update(value_dict)
Expand Down
19 changes: 12 additions & 7 deletions cursorless-talon/src/spoken_forms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Callable, Concatenate
from typing import Callable, Concatenate, Sequence

from talon import app, cron, fs, registry

Expand All @@ -26,9 +26,14 @@

def auto_construct_defaults[**P, R](
spoken_forms: dict[str, ListToSpokenForms],
handle_new_values: Callable[[str, list[SpokenFormEntry]], None],
handle_new_values: Callable[[str, Sequence[SpokenFormEntry]], None],
f: Callable[
Concatenate[str, ListToSpokenForms, Callable[[list[SpokenFormEntry]], None], P],
Concatenate[
str,
ListToSpokenForms,
Callable[[Sequence[SpokenFormEntry]], None],
P,
],
R,
],
):
Expand Down Expand Up @@ -94,7 +99,7 @@ def update():
initialized = False

# Maps from csv name to list of SpokenFormEntry
custom_spoken_forms: dict[str, list[SpokenFormEntry]] = {}
custom_spoken_forms: dict[str, Sequence[SpokenFormEntry]] = {}
spoken_forms_output = SpokenFormsOutput()
spoken_forms_output.init()
graphemes_talon_list = get_graphemes_talon_list()
Expand All @@ -116,7 +121,7 @@ def update_spoken_forms_output():
]
)

def handle_new_values(csv_name: str, values: list[SpokenFormEntry]):
def handle_new_values(csv_name: str, values: Sequence[SpokenFormEntry]):
custom_spoken_forms[csv_name] = values
if initialized:
# On first run, we just do one update at the end, so we suppress
Expand Down Expand Up @@ -163,13 +168,13 @@ def handle_new_values(csv_name: str, values: list[SpokenFormEntry]):
),
handle_csv(
"experimental/actions_custom.csv",
headers=[SPOKEN_FORM_HEADER, "VSCode command"],
headers=(SPOKEN_FORM_HEADER, "VSCode command"),
allow_unknown_values=True,
default_list_name="custom_action",
),
handle_csv(
"experimental/regex_scope_types.csv",
headers=[SPOKEN_FORM_HEADER, "Regex"],
headers=(SPOKEN_FORM_HEADER, "Regex"),
allow_unknown_values=True,
default_list_name="custom_regex_scope_type",
pluralize_lists=["custom_regex_scope_type"],
Expand Down
Loading
Loading