diff --git a/cursorless-talon/src/csv_overrides.py b/cursorless-talon/src/csv_overrides.py index af2fbc3822..57d78448d5 100644 --- a/cursorless-talon/src/csv_overrides.py +++ b/cursorless-talon/src/csv_overrides.py @@ -1,11 +1,10 @@ import csv import typing from collections import defaultdict -from collections.abc import Container from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Callable, Iterable, Optional, TypedDict +from typing import Callable, Container, Iterable, Optional, Sequence, TypedDict from talon import Context, Module, actions, app, fs, settings @@ -49,6 +48,12 @@ class SpokenFormEntry: spoken_forms: list[str] +class ResultsListEntry(TypedDict): + spoken: str + id: str + list: str + + def csv_get_ctx(): return ctx @@ -60,17 +65,17 @@ def csv_get_normalized_ctx(): def init_csv_and_watch_changes( filename: str, default_values: ListToSpokenForms, - handle_new_values: Optional[Callable[[list[SpokenFormEntry]], None]] = None, + handle_new_values: Optional[Callable[[Sequence[SpokenFormEntry]], None]] = None, *, - extra_ignored_values: Optional[list[str]] = None, - extra_allowed_values: Optional[list[str]] = None, + extra_ignored_values: Optional[Sequence[str]] = None, + extra_allowed_values: Optional[Sequence[str]] = None, allow_unknown_values: bool = False, deprecated: bool = False, default_list_name: Optional[str] = None, - headers: list[str] = [SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER], + headers: Optional[Sequence[str]] = None, no_update_file: bool = False, - pluralize_lists: Optional[list[str]] = None, -): + pluralize_lists: Optional[Sequence[str]] = None, +) -> Callable[[], None]: """ Initialize a cursorless settings csv, creating it if necessary, and watch for changes to the csv. Talon lists will be generated based on the keys of @@ -91,21 +96,21 @@ def init_csv_and_watch_changes( `cursorles-settings` dir default_values (ListToSpokenForms): The default values for the lists to be customized in the given csv - handle_new_values (Optional[Callable[[list[SpokenFormEntry]], None]]): A + handle_new_values (Optional[Callable[[Sequence[SpokenFormEntry]], None]]): A callback to be called when the lists are updated - extra_ignored_values (Optional[list[str]]): Don't throw an exception if + extra_ignored_values (Optional[Sequence[str]]): Don't throw an exception if any of these appear as values; just ignore them and don't add them to any list allow_unknown_values (bool): If unknown values appear, just put them in the list default_list_name (Optional[str]): If unknown values are allowed, put any unknown values in this list - headers (list[str]): The headers to use for the csv + headers (Optional[Sequence[str]]): The headers to use for the csv no_update_file (bool): Set this to `True` to indicate that we should not update the csv. This is used generally in case there was an issue coming up with the default set of values so we don't want to persist those to disk - pluralize_lists (list[str]): Create plural version of given lists + pluralize_lists (Optional[Sequence[str]]): Create plural version of given lists """ # Don't allow both `extra_allowed_values` and `allow_unknown_values` assert not (extra_allowed_values and allow_unknown_values) @@ -116,6 +121,8 @@ def init_csv_and_watch_changes( (extra_allowed_values or allow_unknown_values) and not default_list_name ) + if headers is None: + headers = (SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER) if extra_ignored_values is None: extra_ignored_values = [] if extra_allowed_values is None: @@ -137,7 +144,7 @@ def init_csv_and_watch_changes( check_for_duplicates(filename, default_values) create_default_vocabulary_dicts(default_values, pluralize_lists) - def on_watch(path, flags): + def on_watch(path: str, _flags) -> None: if file_path.match(path): current_values, has_errors = read_file( path=file_path, @@ -194,16 +201,16 @@ def on_watch(path, flags): handle_new_values=handle_new_values, ) - def unsubscribe(): + def unsubscribe() -> None: fs.unwatch(file_path.parent, on_watch) return unsubscribe -def check_for_duplicates(filename, default_values): +def check_for_duplicates(filename: str, default_values: ListToSpokenForms): results_map = {} - for list_name, dict in default_values.items(): - for key, value in dict.items(): + for list_name, values in default_values.items(): + for key, value in values.items(): if value in results_map: existing_list_name = results_map[value] warning = f"WARNING ({filename}): Value `{value}` duplicated between lists '{existing_list_name}' and '{list_name}'" @@ -213,16 +220,17 @@ def check_for_duplicates(filename, default_values): results_map[value] = list_name -def is_removed(value: str): +def is_removed(value: str) -> bool: return value.startswith("-") def create_default_vocabulary_dicts( - default_values: dict[str, dict], pluralize_lists: list[str] + default_values: ListToSpokenForms, + pluralize_lists: Sequence[str], ): default_values_updated = {} for key, value in default_values.items(): - updated_dict = {} + updated_dict: dict[str, str] = {} for key2, value2 in value.items(): # Enable deactivated(prefixed with a `-`) items active_key = key2[1:] if key2.startswith("-") else key2 @@ -235,17 +243,17 @@ def create_default_vocabulary_dicts( def update_dicts( default_values: ListToSpokenForms, current_values: dict[str, str], - extra_ignored_values: list[str], - extra_allowed_values: list[str], + extra_ignored_values: Sequence[str], + extra_allowed_values: Sequence[str], allow_unknown_values: bool, default_list_name: str | None, - pluralize_lists: list[str], - handle_new_values: Callable[[list[SpokenFormEntry]], None] | None, -): + pluralize_lists: Sequence[str], + handle_new_values: Callable[[Sequence[SpokenFormEntry]], None] | None, +) -> None: # Create map with all default values results_map: dict[str, ResultsListEntry] = {} - for list_name, obj in default_values.items(): - for spoken, id in obj.items(): + for list_name, values in default_values.items(): + for spoken, id in values.items(): results_map[id] = {"spoken": spoken, "id": id, "list": list_name} # Update result with current values @@ -281,13 +289,9 @@ def update_dicts( handle_new_values(spoken_form_entries) -class ResultsListEntry(TypedDict): - spoken: str - id: str - list: str - - -def generate_spoken_forms(results_list: Iterable[ResultsListEntry]): +def generate_spoken_forms( + results_list: Iterable[ResultsListEntry], +) -> Iterable[SpokenFormEntry]: for obj in results_list: id = obj["id"] spoken = obj["spoken"] @@ -315,25 +319,25 @@ def generate_spoken_forms(results_list: Iterable[ResultsListEntry]): def assign_lists_to_context( ctx: Context, lists: ListToSpokenForms, - pluralize_lists: list[str], -): - for list_name, dict in lists.items(): + pluralize_lists: Sequence[str], +) -> None: + for list_name, values in lists.items(): list_singular_name = get_cursorless_list_name(list_name) - ctx.lists[list_singular_name] = dict + ctx.lists[list_singular_name] = values if list_name in pluralize_lists: list_plural_name = f"{list_singular_name}_plural" - ctx.lists[list_plural_name] = {pluralize(k): v for k, v in dict.items()} + ctx.lists[list_plural_name] = {pluralize(k): v for k, v in values.items()} def update_file( path: Path, - headers: list[str], + headers: Sequence[str], default_values: dict[str, str], - extra_ignored_values: list[str], - extra_allowed_values: list[str], + extra_ignored_values: Sequence[str], + extra_allowed_values: Sequence[str], allow_unknown_values: bool, no_update_file: bool, -): +) -> dict[str, str]: current_values, has_errors = read_file( path=path, headers=headers, @@ -344,7 +348,7 @@ def update_file( ) current_identifiers = current_values.values() - missing = {} + missing: dict[str, str] = {} for key, value in default_values.items(): if value not in current_identifiers: missing[key] = value @@ -357,16 +361,17 @@ def update_file( ) else: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + missing_items = sorted(missing.items()) lines = [ f"# {timestamp} - New entries automatically added by cursorless", - *[create_line(key, missing[key]) for key in sorted(missing)], + *[create_line(key, value) for key, value in missing_items], ] with open(path, "a") as f: f.write("\n\n" + "\n".join(lines)) print(f"New cursorless features added to {path.name}") - for key in sorted(missing): - print(f"{key}: {missing[key]}") + for key, value in missing_items: + print(f"{key}: {value}") print( "See release notes for more info: " "https://github.com/cursorless-dev/cursorless/blob/main/CHANGELOG.md" @@ -376,18 +381,22 @@ def update_file( return current_values -def create_line(*cells: str): - return ", ".join(cells) - - -def create_file(path: Path, headers: list[str], default_values: dict): - lines = [create_line(key, default_values[key]) for key in sorted(default_values)] +def create_file( + path: Path, + headers: Sequence[str], + default_values: dict[str, str], +) -> None: + lines = [create_line(key, value) for key, value in sorted(default_values.items())] lines.insert(0, create_line(*headers)) lines.append("") path.write_text("\n".join(lines)) -def csv_error(path: Path, index: int, message: str, value: str): +def create_line(*cells: str) -> str: + return ", ".join(cells) + + +def csv_error(path: Path, index: int, message: str, value: str) -> None: """Check that an expected condition is true Note that we try to continue reading in this case so cursorless doesn't get bricked @@ -402,19 +411,19 @@ def csv_error(path: Path, index: int, message: str, value: str): def read_file( path: Path, - headers: list[str], + headers: Sequence[str], default_identifiers: Container[str], - extra_ignored_values: list[str], - extra_allowed_values: list[str], + extra_ignored_values: Sequence[str], + extra_allowed_values: Sequence[str], allow_unknown_values: bool, -): +) -> tuple[dict[str, str], bool]: with open(path) as csv_file: # Use `skipinitialspace` to allow spaces before quote. `, "a,b"` csv_reader = csv.reader(csv_file, skipinitialspace=True) rows = list(csv_reader) - result = {} - used_identifiers = [] + result: dict[str, str] = {} + used_identifiers: set[str] = set() has_errors = False seen_headers = False @@ -427,7 +436,7 @@ def read_file( if not seen_headers: seen_headers = True - if row != headers: + if row != list(headers): has_errors = True csv_error(path, i, "Malformed header", create_line(*row)) print(f"Expected '{create_line(*headers)}'") @@ -461,7 +470,7 @@ def read_file( continue result[key] = value - used_identifiers.append(value) + used_identifiers.add(value) if has_errors: app.notify("Cursorless settings error; see log") @@ -469,7 +478,7 @@ def read_file( return result, has_errors -def get_full_path(filename: str): +def get_full_path(filename: str) -> Path: if not filename.endswith(".csv"): filename = f"{filename}.csv" @@ -484,7 +493,7 @@ def get_full_path(filename: str): return (settings_directory / filename).resolve() -def get_super_values(values: ListToSpokenForms): +def get_super_values(values: ListToSpokenForms) -> dict[str, str]: result: dict[str, str] = {} for value_dict in values.values(): result.update(value_dict) diff --git a/cursorless-talon/src/spoken_forms.py b/cursorless-talon/src/spoken_forms.py index 92dc5a8bc2..cbb2a2a0ba 100644 --- a/cursorless-talon/src/spoken_forms.py +++ b/cursorless-talon/src/spoken_forms.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Callable, Concatenate +from typing import Callable, Concatenate, Sequence from talon import app, cron, fs, registry @@ -26,9 +26,14 @@ def auto_construct_defaults[**P, R]( spoken_forms: dict[str, ListToSpokenForms], - handle_new_values: Callable[[str, list[SpokenFormEntry]], None], + handle_new_values: Callable[[str, Sequence[SpokenFormEntry]], None], f: Callable[ - Concatenate[str, ListToSpokenForms, Callable[[list[SpokenFormEntry]], None], P], + Concatenate[ + str, + ListToSpokenForms, + Callable[[Sequence[SpokenFormEntry]], None], + P, + ], R, ], ): @@ -94,7 +99,7 @@ def update(): initialized = False # Maps from csv name to list of SpokenFormEntry - custom_spoken_forms: dict[str, list[SpokenFormEntry]] = {} + custom_spoken_forms: dict[str, Sequence[SpokenFormEntry]] = {} spoken_forms_output = SpokenFormsOutput() spoken_forms_output.init() graphemes_talon_list = get_graphemes_talon_list() @@ -116,7 +121,7 @@ def update_spoken_forms_output(): ] ) - def handle_new_values(csv_name: str, values: list[SpokenFormEntry]): + def handle_new_values(csv_name: str, values: Sequence[SpokenFormEntry]): custom_spoken_forms[csv_name] = values if initialized: # On first run, we just do one update at the end, so we suppress @@ -163,13 +168,13 @@ def handle_new_values(csv_name: str, values: list[SpokenFormEntry]): ), handle_csv( "experimental/actions_custom.csv", - headers=[SPOKEN_FORM_HEADER, "VSCode command"], + headers=(SPOKEN_FORM_HEADER, "VSCode command"), allow_unknown_values=True, default_list_name="custom_action", ), handle_csv( "experimental/regex_scope_types.csv", - headers=[SPOKEN_FORM_HEADER, "Regex"], + headers=(SPOKEN_FORM_HEADER, "Regex"), allow_unknown_values=True, default_list_name="custom_regex_scope_type", pluralize_lists=["custom_regex_scope_type"], diff --git a/packages/cursorless-engine/src/test/fixtures/communitySnippets.fixture.ts b/packages/cursorless-engine/src/test/fixtures/communitySnippets.fixture.ts index bd3f585bb3..e1995c1106 100644 --- a/packages/cursorless-engine/src/test/fixtures/communitySnippets.fixture.ts +++ b/packages/cursorless-engine/src/test/fixtures/communitySnippets.fixture.ts @@ -22,7 +22,13 @@ const snippetAfterAction: ActionDescriptor = { snippets: [ { type: "custom", - body: "```\n$0\n```", + languages: [ + "javascript", + "typescript", + "javascriptreact", + "typescriptreact", + ], + body: 'import * as $0 from "$0";', }, ], }, @@ -34,5 +40,5 @@ const snippetAfterAction: ActionDescriptor = { * Talon tests by relying on our recorded test fixtures alone. */ export const communitySnippetsSpokenFormsFixture = [ - spokenFormTest("snip code after air", snippetAfterAction, undefined), + spokenFormTest("snip import star after air", snippetAfterAction, undefined), ];