#!/usr/bin/env python3
"""Protoc plugin to generate JSON interfaces for TypeScript. Based on dropbox/mypy-protobuf."""

import sys
from collections import defaultdict
from contextlib import contextmanager

import google.protobuf.descriptor_pb2 as d_typed
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf import descriptor

# The `typing` names below are referenced only from `# type:` comments, so they
# are imported behind a flag that is False at runtime (presumably a type
# checker evaluates the guarded branch instead -- confirm against the checker
# in use).
MYPY = False
if MYPY:
    from typing import (
        Any,
        Callable,
        Dict,
        Generator,
        List,
        Set,
        Text,
        Tuple,
        cast,
    )
    from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
else:
    # Provide minimal mypy identifiers to make code run without typing module present
    Text = None

    def cast(type, value):
        """Runtime stand-in for the typing `cast`: return `value` unchanged."""
        return value

# Hack to get around the fact that google protobuf libraries aren't in typeshed yet:
# `d` is the same module as `d_typed`, but declared as `Any` for the type checker.
d = d_typed  # type: Any

# The marker word is assembled from two pieces so that the complete literal
# never appears in this source file.
GENERATED = "@ge" + "nerated"  # So phabricator doesn't think this file is generated
# Comment line that is prepended to the content of every emitted file.
HEADER = "// {} by protoc-gen-json-ts.py.  Do not edit!\n".format(GENERATED)

# Well known types stored as string in JSON.
# Entries are fully-qualified type names with a leading dot, e.g.
# '.google.protobuf.Timestamp'; fields of these types are typed as `string`.
_WELL_KNOWN_TYPE_AS_STRING = {'.google.protobuf.' + t for t in {
    'Duration', 'FieldMask', 'Timestamp',
}}


class PkgWriter(object):
    """Writes a single d.ts file.

    Output is accumulated line by line in ``self.lines`` by the ``write_*``
    methods and joined into the final text by ``write()``.
    """

    def __init__(self, fd, descriptors):
        # type: (d.FileDescriptorProto, Descriptors) -> None
        """Create a writer for the proto file ``fd``.

        ``descriptors`` is the index used to resolve the type names that the
        fields of ``fd`` refer to.
        """
        self.fd = fd
        self.descriptors = descriptors
        self.lines = []  # type: List[Text]
        self.indent = ""

        # Maps import path -> set of (name, mangled_name) pairs; filled by
        # _import() and rendered at the top of the output by write().
        self.imports = defaultdict(set)  # type: Dict[Text, Set[Tuple[Text, Text]]]
        # Names of the messages emitted so far by write_messages().
        self.locals = set()  # type: Set[Text]

    def _import(self, path, name):
        # type: (Text, Text) -> Text
        """Record an import of ``name`` from ``path`` and return its alias.

        Slashes in ``path`` are turned into dots, and the alias is the dotted
        path with every dot replaced by ``___`` followed by ``___`` and
        ``name``, eg. self._import("a/b", "C") -> "a___b___C".

        NOTE(review): nothing in the visible code calls this, and the dotted
        path / ``name: alias`` rendering in write() look like a carry-over
        from a Python stub generator -- confirm before relying on it.
        """
        imp = path.replace('/', '.')
        mangled_name = imp.replace('.', '___') + '___' + name
        self.imports[imp].add((name, mangled_name))
        return mangled_name

    def _import_message(self, type_name):
        # type: (Text) -> Text
        """Return the TypeScript name to use for the proto type ``type_name``.

        ``type_name`` is a fully-qualified proto name with a leading dot.
        Raises KeyError when the name is not present in the descriptor index
        and AssertionError when no local name can be derived from it.
        """
        name = cast(Text, type_name)
        if name[0] == '.' and name[1].isupper():
            # The first segment is already capitalised, so there is no
            # package prefix to strip: drop only the leading dot.
            return name[1:]

        message_fd = self.descriptors.message_to_fd[name]
        if message_fd.name == self.fd.name:
            # Declared in the file being written: keep everything from the
            # first capitalised segment on (the part after the package).
            split = name.split('.')
            for i, segment in enumerate(split):
                if segment and segment[0].isupper():
                    return ".".join(split[i:])

        if name in _WELL_KNOWN_TYPE_AS_STRING:
            return "string"

        # The declaring package is this file's package or one of its parents:
        # refer to the type relative to the declaring package.
        message_package_path = message_fd.package.split('.')
        if self.fd.package.split('.')[:len(message_package_path)] == message_package_path:
            package_prefix = '.{}.'.format(message_fd.package)
            if name.startswith(package_prefix):
                return name[len(package_prefix):]

        # message is from another package: use the fully-qualified name
        if name.startswith('.'):
            return name[1:]

        raise AssertionError(
            'Could not parse local name "{}" from file "{}" and package "{}" while looking at '
            'file "{}" from package "{}"'.format(name, message_fd.name, message_fd.package, self.fd.name, self.fd.package))

    @contextmanager
    def _indent(self):
        # type: () -> Generator
        """Indent the lines written inside the ``with`` body by two spaces."""
        self.indent = self.indent + "  "
        try:
            yield
        finally:
            # Restore the previous level even when the body raises, so the
            # writer is not left permanently indented.
            self.indent = self.indent[:-2]

    def _write_line(self, line, *args):
        # type: (Text, *Text) -> None
        """Append ``line.format(*args)`` at the current indent.

        An empty ``line`` is appended as-is, without indentation. Because the
        line is a format string, literal braces must be doubled.
        """
        if line == "":
            self.lines.append(line)
        else:
            self.lines.append(self.indent + line.format(*args))

    def write_package(self, package_path=None):
        # type: (List[str]) -> None
        """Write the file's enums and messages wrapped in package namespaces.

        Called without arguments it opens a ``declare namespace`` for the
        first package segment and recurses for the remaining ones; the
        innermost call writes the declarations themselves.
        """
        emit = self._write_line
        declare = package_path is None
        if declare:
            # ''.split('.') is [''], which would open a namespace with an
            # empty name for files without a package; drop empty segments.
            package_path = [seg for seg in self.fd.package.split('.') if seg]
        if not package_path:
            self.write_enums(self.fd.enum_type)
            self.write_messages(self.fd.message_type)
            return
        emit("{}namespace {} {{", 'declare ' if declare else '', package_path[0])
        emit("")
        with self._indent():
            self.write_package(package_path[1:])
        emit("}}")
        emit("")

    def write_enums(self, enums):
        # type: (List[d.EnumDescriptorProto]) -> None
        """Write each enum as a union type of its quoted value names."""
        emit = self._write_line
        for enum in enums:
            emit("type {} =", enum.name)
            with self._indent():
                for val in enum.value:
                    emit("| '{}'", val.name)
            emit("")

    def write_messages(self, messages, prefix=""):
        # type: (List[d.DescriptorProto], Text) -> None
        """Write one ``interface`` per message, with every field optional.

        Scalar fields are listed first, then message/group fields, each group
        in declaration order. ``prefix`` is accepted for compatibility and is
        currently unused.

        NOTE(review): ``desc.nested_type`` and ``desc.enum_type`` are not
        written here, although _import_message() can return dotted names such
        as ``Outer.Inner`` -- confirm nested declarations are not needed.
        """
        emit = self._write_line

        for desc in messages:
            self.locals.add(desc.name)
            emit("interface {} {{", desc.name)
            with self._indent():
                scalar_fields = [f for f in desc.field if is_scalar(f)]
                composite_fields = [f for f in desc.field if not is_scalar(f)]
                for field in scalar_fields + composite_fields:
                    self._write_field(field)
            emit("}}")
            emit("")

    def _write_field(self, field):
        # type: (d.FieldDescriptorProto) -> None
        """Write the readonly, optional property line for a single field."""
        emit = self._write_line
        if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
            emit("readonly {}?: {}", field.json_name, self.typescript_type(field))
            return

        if not is_scalar(field):
            msg = self.descriptors.messages[field.type_name]
            if msg.options.map_entry:
                # map generates a special Entry wrapper message whose two
                # fields are the key and the value
                emit("readonly {}?: {{[key: {}]: {}}}", field.json_name, self.typescript_type(msg.field[0]), self.typescript_type(msg.field[1]))
                return

        emit("readonly {}?: readonly {}[]", field.json_name, self.typescript_type(field))

    def typescript_type(self, field):
        # type: (d.FieldDescriptorProto) -> Text
        """Return the TypeScript type for the element type of ``field``.

        Enum, message and group fields resolve through _import_message();
        all other known types map to a TypeScript primitive. Raises
        AssertionError for a type that is not recognised.
        """
        F = d.FieldDescriptorProto
        if field.type in (F.TYPE_ENUM, F.TYPE_MESSAGE, F.TYPE_GROUP):
            return self._import_message(field.type_name)

        # NOTE(review): 64-bit integers are typed as `number` here; confirm
        # this matches how the JSON producer encodes them.
        primitives = {
            F.TYPE_DOUBLE: "number",
            F.TYPE_FLOAT: "number",

            F.TYPE_INT64: "number",
            F.TYPE_UINT64: "number",
            F.TYPE_FIXED64: "number",
            F.TYPE_SFIXED64: "number",
            F.TYPE_SINT64: "number",
            F.TYPE_INT32: "number",
            F.TYPE_UINT32: "number",
            F.TYPE_FIXED32: "number",
            F.TYPE_SFIXED32: "number",
            F.TYPE_SINT32: "number",

            F.TYPE_BOOL: "boolean",
            F.TYPE_STRING: "string",
            F.TYPE_BYTES: "string",
        }  # type: Dict[int, Text]

        if field.type not in primitives:
            # field.type is an int, so it has to be formatted rather than
            # concatenated onto the message.
            raise AssertionError("Unrecognized type: {}".format(field.type))
        return primitives[field.type]

    def write(self):
        # type: () -> Text
        """Return the accumulated output: import lines, a blank line, then the body."""
        imports = []
        for pkg, items in sorted(self.imports.items()):
            imported_names = ', '.join(
                '{}: {}'.format(name, mangled_name) for name, mangled_name in sorted(items))
            imports.append(u"import {{{}}} from '{}'".format(imported_names, pkg))
        imports.append("")
        return "\n".join(imports + self.lines)


def is_scalar(fd):
    # type: (d.FieldDescriptorProto) -> bool
    """Return True unless the field ``fd`` has message or group type."""
    composite_types = (
        d.FieldDescriptorProto.TYPE_MESSAGE,
        d.FieldDescriptorProto.TYPE_GROUP,
    )
    return fd.type not in composite_types


def generate_ts_declarations(descriptors, response, quiet):
    # type: (Descriptors, plugin.CodeGeneratorResponse, bool) -> None
    """Add one generated ``*_pb.d.ts`` file to ``response`` per requested proto file.

    The output name is the proto file name with its ``.proto`` suffix removed
    and dashes replaced by underscores, followed by ``_pb.d.ts``. Unless
    ``quiet`` is set, each output name is reported on stderr.
    """
    proto_suffix = '.proto'
    for requested_name, fd in descriptors.to_generate.items():
        writer = PkgWriter(fd, descriptors)
        writer.write_package()

        assert requested_name == fd.name
        assert fd.name.endswith('.proto')
        stem = fd.name[:-len(proto_suffix)]

        generated = response.file.add()
        generated.name = stem.replace('-', '_') + '_pb.d.ts'
        generated.content = HEADER + writer.write()
        if quiet:
            continue
        print("Writing json-ts to", generated.name, file=sys.stderr)

class Descriptors(object):
    """Index of the files, messages and enums contained in a plugin request.

    Attributes:
        files: every proto file in the request, keyed by file name.
        to_generate: the subset of ``files`` listed in ``file_to_generate``.
        messages: message descriptors keyed by fully-qualified name
            (leading dot, nested messages included).
        message_to_fd: the declaring file for each fully-qualified message
            and enum name.
    """

    def __init__(self, request):
        # type: (plugin.CodeGeneratorRequest) -> None
        by_name = {}
        for proto_file in request.proto_file:
            by_name[proto_file.name] = proto_file

        requested = {}
        for file_name in request.file_to_generate:
            requested[file_name] = by_name[file_name]

        self.files = by_name  # type: Dict[Text, d.FileDescriptorProto]
        self.to_generate = requested  # type: Dict[Text, d.FileDescriptorProto]
        self.messages = {}  # type: Dict[Text, d.DescriptorProto]
        self.message_to_fd = {}  # type: Dict[Text, d.FileDescriptorProto]

        for fd in request.proto_file:
            # Names are rooted at ".<package>." or just "." without a package.
            if fd.package:
                root = ".{}.".format(fd.package)
            else:
                root = "."
            self._index_messages(fd.message_type, root, fd)
            self._index_enums(fd.enum_type, root, fd)

    def _index_enums(self, enums, prefix, fd):
        # type: (RepeatedCompositeFieldContainer[d.EnumDescriptorProto], Text, d.FileDescriptorProto) -> None
        """Record ``fd`` as the declaring file of every enum in ``enums``."""
        for enum in enums:
            self.message_to_fd[prefix + enum.name] = fd

    def _index_messages(self, messages, prefix, fd):
        # type: (RepeatedCompositeFieldContainer[d.DescriptorProto], Text, d.FileDescriptorProto) -> None
        """Record ``messages`` and, recursively, their nested messages and enums."""
        for message in messages:
            full_name = prefix + message.name
            self.messages[full_name] = message
            self.message_to_fd[full_name] = fd

            nested_prefix = full_name + "."
            self._index_messages(message.nested_type, nested_prefix, fd)
            self._index_enums(message.enum_type, nested_prefix, fd)


def main():
    # type: () -> None
    """Run the plugin: read a request from stdin, write the response to stdout."""
    # The request arrives as a serialised message on binary stdin.
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    response = plugin.CodeGeneratorResponse()

    # Generate the TypeScript declarations; progress output is suppressed
    # when the plugin parameter string contains "quiet".
    descriptors = Descriptors(request)
    quiet = "quiet" in request.parameter
    generate_ts_declarations(descriptors, response, quiet)

    # The serialised response goes back on binary stdout.
    sys.stdout.buffer.write(response.SerializeToString())


# Run the plugin only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
