Skip to content

Commit

Permalink
Add CLI for codegen (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored May 26, 2023
1 parent aab94f2 commit a7c22f2
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 53 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ parts/
sdist/
var/
wheels/
out/

# Unit test / coverage reports
htmlcov/
Expand Down
2 changes: 2 additions & 0 deletions atproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from .xrpc_client.client.async_client import AsyncClient
from .xrpc_client.client.client import Client

__version__ = '0.0.0' # placeholder. Dynamic version from Git Tag
__all__ = [
'__version__',
'AsyncClient',
'Client',
'models',
Expand Down
120 changes: 119 additions & 1 deletion atproto/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,119 @@
# TODO(MarshalX): Soon
from pathlib import Path

import click

from atproto import __version__
from atproto.codegen.clients.generate_async_client import gen_client
from atproto.codegen.models.generator import generate_models
from atproto.codegen.namespaces.generator import generate_namespaces


class AliasedGroup(click.Group):
"""Ref: https://click.palletsprojects.com/en/8.1.x/advanced/"""

def get_command(self, ctx, cmd_name):
rv = click.Group.get_command(self, ctx, cmd_name)
if rv is not None:
return rv
matches = [x for x in self.list_commands(ctx) if x.startswith(cmd_name)]
if not matches:
return None
elif len(matches) == 1:
return click.Group.get_command(self, ctx, matches[0])
ctx.fail(f'Too many matches: {", ".join(sorted(matches))}')

def resolve_command(self, ctx, args):
# always return the full command name
_, cmd, args = super().resolve_command(ctx, args)
return cmd.name, cmd, args


@click.group(cls=AliasedGroup)
@click.version_option(__version__)
@click.pass_context
def atproto_cli(ctx: click.Context):
"""CLI of AT Protocol SDK for Python"""
ctx.ensure_object(dict)


@atproto_cli.group(cls=AliasedGroup)
@click.option('--lexicon-dir', type=click.Path(exists=True), default=None, help='Path to dir with .JSON lexicon files.')
@click.pass_context
def gen(ctx: click.Context, lexicon_dir):
if lexicon_dir:
lexicon_dir = Path(lexicon_dir)
ctx.obj['lexicon_dir'] = lexicon_dir


@gen.command(name='all', help='Generated models, namespaces, and async clients with default configs.')
@click.pass_context
def gen_all(_: click.Context):
click.echo('Generating all:')

click.echo('- models...')
_gen_models()
click.echo('- namespaces...')
_gen_namespaces()
click.echo('- async clients...')
_gen_async_version()

click.echo('Done!')


def _gen_models(lexicon_dir=None, output_dir=None):
generate_models(lexicon_dir, output_dir)


def _gen_namespaces(lexicon_dir=None, output_dir=None, async_filename=None, sync_filename=None):
generate_namespaces(lexicon_dir, output_dir, async_filename, sync_filename)


def _gen_async_version():
gen_client('client.py', 'async_client.py')


@gen.command(name='models')
@click.option('--output-dir', type=click.Path(exists=True), default=None)
@click.pass_context
def gen_models(ctx: click.Context, output_dir):
click.echo('Generating models...')

if output_dir:
# FIXME(MarshalX)
output_dir = Path(output_dir)
click.secho(
"It doesn't work with '--output-dir' option very well because of hardcoded imports! Replace by yourself",
fg='red',
)

_gen_models(ctx.obj.get('lexicon_dir'), output_dir)

click.echo('Done!')


@gen.command(name='namespaces')
@click.option('--output-dir', type=click.Path(exists=True), default=None)
@click.option('--async-filename', type=click.STRING, default=None, help='Should end with ".py".')
@click.option('--sync-filename', type=click.STRING, default=None, help='Should end with ".py".')
@click.pass_context
def gen_namespaces(ctx: click.Context, output_dir, async_filename, sync_filename):
click.echo('Generating namespaces...')

if output_dir:
output_dir = Path(output_dir)

_gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir, async_filename, sync_filename)

click.echo('Done!')


@gen.command(name='async')
@click.pass_context
def gen_async_version(_: click.Context):
click.echo('Generating async clients...')
_gen_async_version()
click.echo('Done!')


if __name__ == '__main__':
atproto_cli()
4 changes: 0 additions & 4 deletions atproto/codegen/clients/generate_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,3 @@ def gen_client(input_filename: str, output_filename: str) -> None:

write_code(_CLIENT_DIR.joinpath(output_filename), code)
format_code(_CLIENT_DIR.joinpath(output_filename))


if __name__ == '__main__':
gen_client('client.py', 'async_client.py')
29 changes: 23 additions & 6 deletions atproto/codegen/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
from pathlib import Path

from atproto.lexicon import models
from atproto.lexicon.parser import lexicon_parse_dir
Expand Down Expand Up @@ -34,6 +35,22 @@
LexDB = t.Dict[NSID, LexDefs]


class _LexiconDir:
dir_path: t.Optional[Path]

def __init__(self, default_path: Path = None):
self.dir_path = default_path

def set(self, path: Path):
self.dir_path = path

def get(self) -> Path:
return self.dir_path


lexicon_dir = _LexiconDir()


def _filter_defs_by_type(defs: t.Dict[str, models.LexDefinition], def_types: set) -> LexDefs:
return {k: v for k, v in defs.items() if v.type in def_types}

Expand All @@ -51,27 +68,27 @@ def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set)


def build_params_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_PARAMS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_PARAMS)


def build_data_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_DATA)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DATA)


def build_response_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_RESPONSES)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RESPONSES)


def build_def_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_DEF)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DEF)


def build_record_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_RECORDS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RECORDS)


def build_refs_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_REFS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_REFS)


if __name__ == '__main__':
Expand Down
22 changes: 12 additions & 10 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def _generate_init_files(root_package_path: Path) -> None:
if dir_name.startswith('__'):
continue

import_parts = root.parts[root.joinpath(dir_name).parts.index('xrpc_client') :]
import_parts = root.parts[root.joinpath(dir_name).parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

if dir_name in {'app', 'com'}:
Expand All @@ -474,7 +474,7 @@ def _generate_init_files(root_package_path: Path) -> None:
if file_name.startswith('__'):
continue

import_parts = root.parts[root.parts.index('xrpc_client') :]
import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

import_lines.append(f'from atproto.{from_import} import {file_name[:-3]}')
Expand Down Expand Up @@ -529,7 +529,7 @@ def _generate_import_aliases(root_package_path: Path):
if file.startswith('.') or file.startswith('__') or file.endswith('.pyc'):
continue

import_parts = root.parts[root.parts.index('xrpc_client') :]
import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

nsid_parts = list(root.parts[root.parts.index('models') + 1 :]) + file[:-3].split('_')
Expand All @@ -540,7 +540,15 @@ def _generate_import_aliases(root_package_path: Path):
write_code(_MODELS_OUTPUT_DIR.joinpath('__init__.py'), join_code(import_lines))


def generate_models():
def generate_models(lexicon_dir: t.Optional[Path] = None, output_dir: t.Optional[Path] = None):
if lexicon_dir:
builder.lexicon_dir.set(lexicon_dir)

if output_dir:
# TODO(MarshalX): Temp hack for CLI. Pass output_dir everywhere.
global _MODELS_OUTPUT_DIR
_MODELS_OUTPUT_DIR = output_dir

_generate_params_models(builder.build_params_models())
_generate_data_models(builder.build_data_models())
_generate_response_models(builder.build_response_models())
Expand All @@ -556,9 +564,3 @@ def generate_models():
_generate_import_aliases(_MODELS_OUTPUT_DIR)

format_code(_MODELS_OUTPUT_DIR)

print('Done')


if __name__ == '__main__':
generate_models()
4 changes: 2 additions & 2 deletions atproto/codegen/namespaces/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def build_namespace_tree(lexicons: t.List[LexiconDoc]) -> dict:
return namespace_tree


def build_namespaces() -> dict:
lexicons = lexicon_parse_dir()
def build_namespaces(lexicon_dir=None) -> dict:
lexicons = lexicon_parse_dir(lexicon_dir)
namespace_tree = build_namespace_tree(lexicons)

return namespace_tree
Expand Down
45 changes: 28 additions & 17 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from atproto.nsid import NSID

_NAMESPACES_OUTPUT_DIR = Path(__file__).parent.parent.parent.joinpath('xrpc_client', 'namespaces')
_NAMESPACES_CLIENT_FILE_PATH = _NAMESPACES_OUTPUT_DIR.joinpath('client', 'raw.py')

_NAMESPACES_SYNC_FILENAME = 'sync_ns.py'
_NAMESPACES_ASYNC_FILENAME = 'async_ns.py'
Expand Down Expand Up @@ -111,7 +110,7 @@ def _get_method_docstring(method_info: MethodInfo) -> str:

doc_string.append(f'{_(2)}Returns:')

return_type = _get_namespace_method_return_type(method_info)
return_type, __ = _get_namespace_method_return_type(method_info)

return_type_desc = 'Output model'
if return_type == 'bool':
Expand Down Expand Up @@ -167,7 +166,7 @@ def _override_arg_line(name: str, model_name: str) -> str:

lines.append(f"{_(2)}response = {c}self._client.{method_name}({invoke_args_str})")

return_type = _get_namespace_method_return_type(method_info)
return_type, __ = _get_namespace_method_return_type(method_info)
lines.append(f"{_(2)}return get_response_model(response, {return_type})")

return join_code(lines)
Expand Down Expand Up @@ -247,30 +246,35 @@ def is_optional_arg(lex_obj) -> bool:
return ', '.join(args)


def _get_namespace_method_return_type(method_info: MethodInfo) -> str:
def _get_namespace_method_return_type(method_info: MethodInfo) -> t.Tuple[str, bool]:
model_name_suffix = ''
if method_info.definition.output and isinstance(method_info.definition.output.schema, LexRef):
# fix collisions with type aliases
# example of collisions: com.atproto.admin.getRepo, com.atproto.sync.getRepo
# could be solved by separating models into different folders using segments of NSID
model_name_suffix = 'Ref'

is_model = False
return_type = 'bool' # return success of response
if method_info.definition.output:
# example of methods without response: app.bsky.graph.muteActor, app.bsky.graph.muteActor
is_model = True
return_type = f'models.{get_import_path(method_info.nsid)}.{OUTPUT_MODEL}{model_name_suffix}'

return return_type
return return_type, is_model


def _get_namespace_method_signature(method_info: MethodInfo, *, sync: bool) -> str:
d, c = get_sync_async_keywords(sync=sync)

name = convert_camel_case_to_snake_case(method_info.name)
args = _get_namespace_method_signature_args(method_info)
return_type = _get_namespace_method_return_type(method_info)
return_type, is_model = _get_namespace_method_return_type(method_info)

return f'{_(1)}{d}def {name}({args}) -> {return_type}:'
if is_model:
return_type = f"'{return_type}'"

return f"{_(1)}{d}def {name}({args}) -> {return_type}:"


def _get_namespace_methods_block(methods_info: t.List[MethodInfo], sync: bool) -> str:
Expand Down Expand Up @@ -315,25 +319,32 @@ def _generate_namespace_in_output(namespace_tree: t.Union[dict, list], output: t
output.append(_get_namespace_methods_block(methods, sync=sync))


def generate_namespaces() -> None:
namespace_tree = build_namespaces()
def generate_namespaces(
lexicon_dir: t.Optional[Path] = None,
output_dir: t.Optional[Path] = None,
async_filename: t.Optional[str] = None,
sync_filename: t.Optional[str] = None,
) -> None:
if not output_dir:
output_dir = _NAMESPACES_OUTPUT_DIR
if not async_filename:
async_filename = _NAMESPACES_ASYNC_FILENAME
if not sync_filename:
sync_filename = _NAMESPACES_SYNC_FILENAME

namespace_tree = build_namespaces(lexicon_dir)

for sync in (True, False):
generated_code_lines_buffer = []
_generate_namespace_in_output(namespace_tree, generated_code_lines_buffer, sync=sync)

code = join_code([_get_namespace_imports(), *generated_code_lines_buffer])

filename = _NAMESPACES_SYNC_FILENAME if sync else _NAMESPACES_ASYNC_FILENAME
filepath = _NAMESPACES_OUTPUT_DIR.joinpath(filename)
filename = sync_filename if sync else async_filename
filepath = output_dir.joinpath(filename)

write_code(filepath, code)

# TODO(MarshalX): generate ClientRaw as root of namespaces

format_code(_NAMESPACES_OUTPUT_DIR)
print('Done')


if __name__ == '__main__':
generate_namespaces()
format_code(output_dir)
4 changes: 2 additions & 2 deletions atproto/xrpc_client/client/async_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class AsyncClientRaw(AsyncClientBase):
"""Group all root namespaces"""

com: async_ns.ComNamespace
bsky: async_ns.BskyNamespace
com: 'async_ns.ComNamespace'
bsky: 'async_ns.BskyNamespace'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
Loading

0 comments on commit a7c22f2

Please sign in to comment.