Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI for codegen #43

Merged
merged 1 commit into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ parts/
sdist/
var/
wheels/
out/

# Unit test / coverage reports
htmlcov/
Expand Down
2 changes: 2 additions & 0 deletions atproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from .xrpc_client.client.async_client import AsyncClient
from .xrpc_client.client.client import Client

__version__ = '0.0.0' # placeholder. Dynamic version from Git Tag
__all__ = [
'__version__',
'AsyncClient',
'Client',
'models',
Expand Down
120 changes: 119 additions & 1 deletion atproto/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,119 @@
# TODO(MarshalX): Soon
from pathlib import Path

import click

from atproto import __version__
from atproto.codegen.clients.generate_async_client import gen_client
from atproto.codegen.models.generator import generate_models
from atproto.codegen.namespaces.generator import generate_namespaces


class AliasedGroup(click.Group):
"""Ref: https://click.palletsprojects.com/en/8.1.x/advanced/"""

def get_command(self, ctx, cmd_name):
rv = click.Group.get_command(self, ctx, cmd_name)
if rv is not None:
return rv
matches = [x for x in self.list_commands(ctx) if x.startswith(cmd_name)]
if not matches:
return None
elif len(matches) == 1:
return click.Group.get_command(self, ctx, matches[0])
ctx.fail(f'Too many matches: {", ".join(sorted(matches))}')

def resolve_command(self, ctx, args):
# always return the full command name
_, cmd, args = super().resolve_command(ctx, args)
return cmd.name, cmd, args


@click.group(cls=AliasedGroup)
@click.version_option(__version__)
@click.pass_context
def atproto_cli(ctx: click.Context):
"""CLI of AT Protocol SDK for Python"""
ctx.ensure_object(dict)


@atproto_cli.group(cls=AliasedGroup)
@click.option('--lexicon-dir', type=click.Path(exists=True), default=None, help='Path to dir with .JSON lexicon files.')
@click.pass_context
def gen(ctx: click.Context, lexicon_dir):
if lexicon_dir:
lexicon_dir = Path(lexicon_dir)
ctx.obj['lexicon_dir'] = lexicon_dir


@gen.command(name='all', help='Generated models, namespaces, and async clients with default configs.')
@click.pass_context
def gen_all(_: click.Context):
click.echo('Generating all:')

click.echo('- models...')
_gen_models()
click.echo('- namespaces...')
_gen_namespaces()
click.echo('- async clients...')
_gen_async_version()

click.echo('Done!')


def _gen_models(lexicon_dir=None, output_dir=None):
generate_models(lexicon_dir, output_dir)


def _gen_namespaces(lexicon_dir=None, output_dir=None, async_filename=None, sync_filename=None):
generate_namespaces(lexicon_dir, output_dir, async_filename, sync_filename)


def _gen_async_version():
gen_client('client.py', 'async_client.py')


@gen.command(name='models')
@click.option('--output-dir', type=click.Path(exists=True), default=None)
@click.pass_context
def gen_models(ctx: click.Context, output_dir):
click.echo('Generating models...')

if output_dir:
# FIXME(MarshalX)
output_dir = Path(output_dir)
click.secho(
"It doesn't work with '--output-dir' option very well because of hardcoded imports! Replace by yourself",
fg='red',
)

_gen_models(ctx.obj.get('lexicon_dir'), output_dir)

click.echo('Done!')


@gen.command(name='namespaces')
@click.option('--output-dir', type=click.Path(exists=True), default=None)
@click.option('--async-filename', type=click.STRING, default=None, help='Should end with ".py".')
@click.option('--sync-filename', type=click.STRING, default=None, help='Should end with ".py".')
@click.pass_context
def gen_namespaces(ctx: click.Context, output_dir, async_filename, sync_filename):
click.echo('Generating namespaces...')

if output_dir:
output_dir = Path(output_dir)

_gen_namespaces(ctx.obj.get('lexicon_dir'), output_dir, async_filename, sync_filename)

click.echo('Done!')


@gen.command(name='async')
@click.pass_context
def gen_async_version(_: click.Context):
click.echo('Generating async clients...')
_gen_async_version()
click.echo('Done!')


if __name__ == '__main__':
atproto_cli()
4 changes: 0 additions & 4 deletions atproto/codegen/clients/generate_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,3 @@ def gen_client(input_filename: str, output_filename: str) -> None:

write_code(_CLIENT_DIR.joinpath(output_filename), code)
format_code(_CLIENT_DIR.joinpath(output_filename))


if __name__ == '__main__':
gen_client('client.py', 'async_client.py')
29 changes: 23 additions & 6 deletions atproto/codegen/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
from pathlib import Path

from atproto.lexicon import models
from atproto.lexicon.parser import lexicon_parse_dir
Expand Down Expand Up @@ -34,6 +35,22 @@
LexDB = t.Dict[NSID, LexDefs]


class _LexiconDir:
dir_path: t.Optional[Path]

def __init__(self, default_path: Path = None):
self.dir_path = default_path

def set(self, path: Path):
self.dir_path = path

def get(self) -> Path:
return self.dir_path


lexicon_dir = _LexiconDir()


def _filter_defs_by_type(defs: t.Dict[str, models.LexDefinition], def_types: set) -> LexDefs:
return {k: v for k, v in defs.items() if v.type in def_types}

Expand All @@ -51,27 +68,27 @@ def _build_nsid_to_defs_map(lexicons: t.List[models.LexiconDoc], def_types: set)


def build_params_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_PARAMS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_PARAMS)


def build_data_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_DATA)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DATA)


def build_response_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_RESPONSES)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RESPONSES)


def build_def_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_DEF)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_DEF)


def build_record_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_RECORDS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RECORDS)


def build_refs_models() -> LexDB:
return _build_nsid_to_defs_map(lexicon_parse_dir(), _LEX_DEF_TYPES_FOR_REFS)
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_REFS)


if __name__ == '__main__':
Expand Down
22 changes: 12 additions & 10 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def _generate_init_files(root_package_path: Path) -> None:
if dir_name.startswith('__'):
continue

import_parts = root.parts[root.joinpath(dir_name).parts.index('xrpc_client') :]
import_parts = root.parts[root.joinpath(dir_name).parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

if dir_name in {'app', 'com'}:
Expand All @@ -474,7 +474,7 @@ def _generate_init_files(root_package_path: Path) -> None:
if file_name.startswith('__'):
continue

import_parts = root.parts[root.parts.index('xrpc_client') :]
import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

import_lines.append(f'from atproto.{from_import} import {file_name[:-3]}')
Expand Down Expand Up @@ -529,7 +529,7 @@ def _generate_import_aliases(root_package_path: Path):
if file.startswith('.') or file.startswith('__') or file.endswith('.pyc'):
continue

import_parts = root.parts[root.parts.index('xrpc_client') :]
import_parts = root.parts[root.parts.index(_MODELS_OUTPUT_DIR.name) :]
from_import = '.'.join(import_parts)

nsid_parts = list(root.parts[root.parts.index('models') + 1 :]) + file[:-3].split('_')
Expand All @@ -540,7 +540,15 @@ def _generate_import_aliases(root_package_path: Path):
write_code(_MODELS_OUTPUT_DIR.joinpath('__init__.py'), join_code(import_lines))


def generate_models():
def generate_models(lexicon_dir: t.Optional[Path] = None, output_dir: t.Optional[Path] = None):
if lexicon_dir:
builder.lexicon_dir.set(lexicon_dir)

if output_dir:
# TODO(MarshalX): Temp hack for CLI. Pass output_dir everywhere.
global _MODELS_OUTPUT_DIR
_MODELS_OUTPUT_DIR = output_dir

_generate_params_models(builder.build_params_models())
_generate_data_models(builder.build_data_models())
_generate_response_models(builder.build_response_models())
Expand All @@ -556,9 +564,3 @@ def generate_models():
_generate_import_aliases(_MODELS_OUTPUT_DIR)

format_code(_MODELS_OUTPUT_DIR)

print('Done')


if __name__ == '__main__':
generate_models()
4 changes: 2 additions & 2 deletions atproto/codegen/namespaces/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def build_namespace_tree(lexicons: t.List[LexiconDoc]) -> dict:
return namespace_tree


def build_namespaces() -> dict:
lexicons = lexicon_parse_dir()
def build_namespaces(lexicon_dir=None) -> dict:
lexicons = lexicon_parse_dir(lexicon_dir)
namespace_tree = build_namespace_tree(lexicons)

return namespace_tree
Expand Down
45 changes: 28 additions & 17 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from atproto.nsid import NSID

_NAMESPACES_OUTPUT_DIR = Path(__file__).parent.parent.parent.joinpath('xrpc_client', 'namespaces')
_NAMESPACES_CLIENT_FILE_PATH = _NAMESPACES_OUTPUT_DIR.joinpath('client', 'raw.py')

_NAMESPACES_SYNC_FILENAME = 'sync_ns.py'
_NAMESPACES_ASYNC_FILENAME = 'async_ns.py'
Expand Down Expand Up @@ -111,7 +110,7 @@ def _get_method_docstring(method_info: MethodInfo) -> str:

doc_string.append(f'{_(2)}Returns:')

return_type = _get_namespace_method_return_type(method_info)
return_type, __ = _get_namespace_method_return_type(method_info)

return_type_desc = 'Output model'
if return_type == 'bool':
Expand Down Expand Up @@ -167,7 +166,7 @@ def _override_arg_line(name: str, model_name: str) -> str:

lines.append(f"{_(2)}response = {c}self._client.{method_name}({invoke_args_str})")

return_type = _get_namespace_method_return_type(method_info)
return_type, __ = _get_namespace_method_return_type(method_info)
lines.append(f"{_(2)}return get_response_model(response, {return_type})")

return join_code(lines)
Expand Down Expand Up @@ -247,30 +246,35 @@ def is_optional_arg(lex_obj) -> bool:
return ', '.join(args)


def _get_namespace_method_return_type(method_info: MethodInfo) -> str:
def _get_namespace_method_return_type(method_info: MethodInfo) -> t.Tuple[str, bool]:
model_name_suffix = ''
if method_info.definition.output and isinstance(method_info.definition.output.schema, LexRef):
# fix collisions with type aliases
# example of collisions: com.atproto.admin.getRepo, com.atproto.sync.getRepo
# could be solved by separating models into different folders using segments of NSID
model_name_suffix = 'Ref'

is_model = False
return_type = 'bool' # return success of response
if method_info.definition.output:
# example of methods without response: app.bsky.graph.muteActor, app.bsky.graph.muteActor
is_model = True
return_type = f'models.{get_import_path(method_info.nsid)}.{OUTPUT_MODEL}{model_name_suffix}'

return return_type
return return_type, is_model


def _get_namespace_method_signature(method_info: MethodInfo, *, sync: bool) -> str:
d, c = get_sync_async_keywords(sync=sync)

name = convert_camel_case_to_snake_case(method_info.name)
args = _get_namespace_method_signature_args(method_info)
return_type = _get_namespace_method_return_type(method_info)
return_type, is_model = _get_namespace_method_return_type(method_info)

return f'{_(1)}{d}def {name}({args}) -> {return_type}:'
if is_model:
return_type = f"'{return_type}'"

return f"{_(1)}{d}def {name}({args}) -> {return_type}:"


def _get_namespace_methods_block(methods_info: t.List[MethodInfo], sync: bool) -> str:
Expand Down Expand Up @@ -315,25 +319,32 @@ def _generate_namespace_in_output(namespace_tree: t.Union[dict, list], output: t
output.append(_get_namespace_methods_block(methods, sync=sync))


def generate_namespaces() -> None:
namespace_tree = build_namespaces()
def generate_namespaces(
lexicon_dir: t.Optional[Path] = None,
output_dir: t.Optional[Path] = None,
async_filename: t.Optional[str] = None,
sync_filename: t.Optional[str] = None,
) -> None:
if not output_dir:
output_dir = _NAMESPACES_OUTPUT_DIR
if not async_filename:
async_filename = _NAMESPACES_ASYNC_FILENAME
if not sync_filename:
sync_filename = _NAMESPACES_SYNC_FILENAME

namespace_tree = build_namespaces(lexicon_dir)

for sync in (True, False):
generated_code_lines_buffer = []
_generate_namespace_in_output(namespace_tree, generated_code_lines_buffer, sync=sync)

code = join_code([_get_namespace_imports(), *generated_code_lines_buffer])

filename = _NAMESPACES_SYNC_FILENAME if sync else _NAMESPACES_ASYNC_FILENAME
filepath = _NAMESPACES_OUTPUT_DIR.joinpath(filename)
filename = sync_filename if sync else async_filename
filepath = output_dir.joinpath(filename)

write_code(filepath, code)

# TODO(MarshalX): generate ClientRaw as root of namespaces

format_code(_NAMESPACES_OUTPUT_DIR)
print('Done')


if __name__ == '__main__':
generate_namespaces()
format_code(output_dir)
4 changes: 2 additions & 2 deletions atproto/xrpc_client/client/async_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class AsyncClientRaw(AsyncClientBase):
"""Group all root namespaces"""

com: async_ns.ComNamespace
bsky: async_ns.BskyNamespace
com: 'async_ns.ComNamespace'
bsky: 'async_ns.BskyNamespace'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
Loading