#!/usr/bin/env python
from __future__ import print_function, unicode_literals
import json, sys, os, argparse, itertools, re, shutil
from contextlib import closing

import json_delta

# msvcrt is only importable on Windows.  When it is available, switch the
# stdin and stdout file descriptors to binary mode (os.O_BINARY); on other
# platforms the import fails and the streams are left untouched.
# NOTE(review): presumably this avoids newline translation of the JSON/udiff
# bytes on Windows -- confirm.
try:
    import msvcrt

    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
except ImportError:
    pass

# Matches one unified-diff header line: three '+' or '-' characters, a
# filename, and a timestamp that starts with "YYYY-MM-DD HH:MM:SS" (with an
# optional fractional part) followed by at least one more character.
# Named groups: "side" ('---' or '+++'), "filename" and "timestamp".
# Compiled with re.MULTILINE so that ^ and $ anchor to a single line, and
# re.VERBOSE so the layout whitespace inside the pattern is ignored.
HEADER_PATTERN = re.compile(
    r"""
        ^(?P<side>[+-]{3})\s
         (?P<filename>.+)\s
         (?P<timestamp>\d{4}-\d{2}-\d{2}\s
                       \d{2}:\d{2}:\d{2}(?:\.\d+)?\s
                       .+)$
    """,
    re.MULTILINE | re.VERBOSE,
)


def udiff_with_headers(diff):
    """Parse the two header lines of a unified diff, if present.

    ``diff`` must begin with a ``---`` header line immediately followed
    by a ``+++`` header line, both matching ``HEADER_PATTERN``.

    Returns a pair ``(left, right)`` of dictionaries with the keys
    ``'filename'`` and ``'timestamp'``, taken from the ``---`` and
    ``+++`` lines respectively, or ``False`` when the headers are
    missing or appear in the wrong order.
    """
    minus_header = HEADER_PATTERN.match(diff)
    if minus_header is None or minus_header.group("side") != "---":
        return False

    # Skip the character that ends the first header line (normally the
    # newline) and try to match the second header at the start of the rest.
    remainder = diff[minus_header.end() + 1 :]
    plus_header = HEADER_PATTERN.match(remainder)
    if plus_header is None or plus_header.group("side") != "+++":
        return False

    # The side marker has served its purpose; callers only get the
    # filename and timestamp fields.
    left, right = (
        {key: value for key, value in header.groupdict().items() if key != "side"}
        for header in (minus_header, plus_header)
    )
    return (left, right)


def is_udiff(diff):
    """Report whether ``diff`` has the line structure of a unified diff.

    Every line must either be empty or start with ``'+'``, ``'-'`` or a
    space.  An empty ``diff`` is accepted.
    """
    line_starts = ("+", "-", " ", "\n")
    if not diff:
        return True
    if diff[0] not in line_starts:
        return False
    # Walk adjacent character pairs: whatever follows a newline begins a
    # new line and must therefore be one of the permitted characters.
    return all(
        following in line_starts
        for current, following in zip(diff, diff[1:])
        if current == "\n"
    )


def parse_normal(diff):
    """Try to read ``diff`` as a normal (JSON-serialized) diff.

    Returns ``False`` when ``json_delta._util.decode_json`` raises
    ``ValueError``; otherwise returns whatever
    ``json_delta._util.check_diff_structure`` reports for the decoded
    value.
    """
    try:
        decoded = json_delta._util.decode_json(diff)
    except ValueError:
        return False
    else:
        return json_delta._util.check_diff_structure(decoded)


def diff_if_any(diff, ns, norm_check=parse_normal):
    """Validate ``diff``, record its format on ``ns`` and return it as JSON.

    ``diff`` is the raw diff (bytes or text, or an already decoded
    structure when a suitable ``norm_check`` is supplied).  ``ns`` is the
    argparse namespace; its ``normal``, ``unified``, ``diff`` and
    ``output`` attributes may be updated.  ``norm_check`` is called with
    ``diff`` and must return ``False`` to reject it as a normal diff.

    Returns a JSON string: the serialized normal diff, or the unified
    diff text serialized as a JSON string.

    Raises a bare ``Exception`` (after printing a message to stderr) if
    a normal diff is combined with ``--reverse`` or if ``diff`` cannot be
    parsed in any supported format.
    """
    # The normal format is tried when it was requested explicitly, or when
    # the unified format was not.
    if ns.normal or not ns.unified:
        checked = norm_check(diff)
        # Compared by equality rather than truthiness, so that an empty
        # (falsy) but valid result is still accepted.
        if checked != False:
            if ns.reverse:
                print("Cannot reverse JSON-format patches.", file=sys.stderr)
                raise Exception
            ns.normal, ns.unified = True, False
            # A bare True result is serialized as an empty diff.
            return json.dumps([] if checked is True else checked)

    text = json_delta._util.decode_udiff(diff) if isinstance(diff, bytes) else diff

    headers = udiff_with_headers(text)
    if headers:
        update_namespace_headers(ns, headers)
        return json.dumps(text)

    if not is_udiff(text):
        print("Cannot parse diff.", file=sys.stderr)
        raise Exception

    # Header-less unified diff: remember it on the namespace and make sure
    # there is somewhere to write the result.
    ns.unified = True
    ns.diff = text
    if ns.output is None:
        ns.output = sys.stdout
    return json.dumps(text)


def update_namespace_headers(ns, udiff_parse):
    """Fill in missing fields of ``ns`` from parsed unified-diff headers.

    ``udiff_parse`` is the ``(left, right)`` pair of header dictionaries
    returned by ``udiff_with_headers``.  ``ns.unified`` is always set to
    ``True``.  If ``ns.struc`` is ``None``, the file named in the headers
    is opened and its decoded JSON content is stored there.  If
    ``ns.output`` is ``None``, it is set to that same file path.

    The file name is taken from the ``---`` header, or from the ``+++``
    header when ``ns.reverse`` is set (the patch is then assumed to have
    been made with old and new swapped, so the file on disk is the one
    named on the ``+++`` side).  ``ns.strip`` leading path components
    are removed with ``trim_path``, as for ``patch -p``.

    Raises ``IOError`` (after printing a message to stderr) if the file
    cannot be read.
    """
    ns.unified = True
    left, right = udiff_parse
    if ns.struc is not None and ns.output is not None:
        # Nothing needs to be inferred from the headers.
        return

    # Previously ``struc_path`` was never assigned, so reaching either
    # branch below raised NameError.
    header = right if getattr(ns, "reverse", False) else left
    struc_path = trim_path(header["filename"], getattr(ns, "strip", 0))

    if ns.struc is None:
        try:
            with open(struc_path) as f:
                ns.struc = json.load(f)
        except IOError:
            print(("Couldn't read specified file %s" % struc_path), file=sys.stderr)
            raise
    if ns.output is None:
        # NOTE(review): this stores a path string, whereas main() treats
        # ns.output as an open file object -- confirm the intended type.
        ns.output = struc_path


def infer(ns):
    """Work out where the structure and the diff come from.

    Depending on which of ``ns.struc`` and ``ns.diff`` were supplied,
    the missing input is read from standard input.  ``ns`` may be
    updated as a side effect (by ``diff_if_any`` and
    ``update_namespace_headers``).

    Returns a triple ``(struc, diff, both)``:

    * neither given, stdin is a unified diff with headers:
      ``(ns.struc, <stdin bytes>, None)``;
    * neither given, stdin is a JSON pair ``[struc, diff]``:
      ``(None, None, <stdin bytes>)``, with the validated diff stored
      in ``ns.diff``;
    * otherwise: ``(ns.struc, <validated diff>, None)``, the diff being
      read from ``ns.diff`` or, if that is ``None``, from stdin.

    Exceptions from decoding or validation propagate to the caller.
    """
    if ns.struc is None and ns.diff is None:
        stdin = json_delta._util.read_bytestring(sys.stdin)
        try:
            udiff_parse = udiff_with_headers(json_delta._util.decode_udiff(stdin))
        except (ValueError, UnicodeError):
            udiff_parse = None

        if udiff_parse:
            # NOTE(review): unlike the other branches, the diff returned here
            # is the raw stdin bytes rather than the output of diff_if_any --
            # confirm that load_and_upatch accepts it in this form.
            diff = stdin
            update_namespace_headers(ns, udiff_parse)
            struc = ns.struc
            # ``both`` was previously left unassigned on this path, so the
            # return statement raised UnboundLocalError.
            both = None
        else:
            try:
                # Unpacking checks that stdin holds exactly two elements.
                _, diff = json_delta._util.decode_json(stdin)
            except Exception:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt
                # and SystemExit are not reported as a parse failure.
                print(
                    "Couldn't read structure and diff from standard input",
                    file=sys.stderr,
                )
                raise
            ns.diff = diff_if_any(diff, ns, json_delta._util.check_diff_structure)
            # The combined document is handed on as ``both``.
            struc = None
            diff = None
            both = stdin
    elif ns.diff is None:
        struc = ns.struc
        diff = diff_if_any(json_delta._util.read_bytestring(sys.stdin), ns)
        both = None
    else:
        struc = ns.struc
        diff = diff_if_any(json_delta._util.read_bytestring(ns.diff), ns)
        both = None
    return struc, diff, both


def trim_path(path, slashes):
    """Drop the first ``slashes`` components from a normalized ``path``.

    Modelled on the -p option of patch: the path is normalized first, so
    runs of adjacent separators count as a single one.  When the path
    does not have more than ``slashes`` components, it is returned whole
    (in normalized form) instead of being trimmed.

    (Note: the following doctests will fail on a platform where os.sep
    != '/'!)

    >>> trim_path('/foo/bar/baz', 1)
    'foo/bar/baz'
    >>> trim_path('/foo/bar/baz', 2)
    'bar/baz'
    >>> trim_path('/foo/bar/baz', 3)
    'baz'
    >>> trim_path('/foo//bar/baz', 4)
    '/foo/bar/baz'
    """
    components = os.path.normpath(path).split(os.sep)
    # Too few components to strip: keep the whole (normalized) path.
    kept = components if len(components) <= slashes else components[slashes:]
    return os.sep.join(kept)


def main():
    """Command-line entry point for json_patch.

    Parses the command line, determines the structure and diff to use
    (see ``infer``), applies the patch with ``json_delta`` and writes the
    compactly serialized result, encoded with ``--encoding``, to the
    output stream.  Returns 0.
    """
    parser = argparse.ArgumentParser(
        prog="json_patch",
        description="""Applies a patch of the form generated by json_diff to a
        JSON-serialized data structure.  If no arguments are
        specified, stdin will be expected to be a JSON structure: this
        could be a pair [struc, diff], in which case the output will
        be written to stdout.  Or it could be a unified diff with
        headers, as output by json_delta -u.  In this case, the
        program reads the name of the file to patch against out of the
        headers.  Output, however, is always written to the file
        specified by the '--output' option (stdout by default).
        """,
        epilog="Report bugs to <himself@phil-roberts.name>",
    )
    # Both positional arguments are optional and opened in binary mode.
    parser.add_argument(
        "struc",
        nargs="?",
        type=argparse.FileType("rb"),
        help="Filename for the structure to modify.",
    )
    parser.add_argument(
        "diff",
        nargs="?",
        type=argparse.FileType("rb"),
        help=("Filename for the diff to apply.  Will read from stdin " "by default."),
    )
    parser.add_argument(
        "--output",
        "-o",
        type=argparse.FileType("wb"),
        help="Filename to output to.  Standard output is the default.",
        metavar="FILE",
        default=sys.stdout,
    )
    parser.add_argument(
        "--strip",
        "-p",
        type=int,
        metavar="NUM",
        default=0,
        help="Strip NUM leading components from file names.",
    )
    parser.add_argument(
        "--reverse",
        "-R",
        action="store_true",
        help="Assume patches were created with old and new files swapped.",
    )
    # "{}" is filled in here by str.format with the package version;
    # "%(prog)s" is left in place for argparse to expand.
    version = """%(prog)s - part of json-delta {}
Copyright 2012-2020 Philip J. Roberts <himself@phil-roberts.name>. 
BSD License applies; see http://opensource.org/licenses/BSD-2-Clause
""".format(
        json_delta.__VERSION__
    )
    parser.add_argument(
        "--encoding",
        help="Select encoding for output (default UTF-8).",
        default="UTF-8",
    )
    parser.add_argument("--version", action="version", version=version)

    # --unified and --normal cannot be given together.
    format_grp = parser.add_mutually_exclusive_group()
    format_grp.add_argument(
        "--unified",
        "-u",
        action="store_true",
        help=(
            "Interpret the patch as a unified difference "
            "(As emitted by 'json_diff -u')."
        ),
    )
    format_grp.add_argument(
        "--normal",
        "-n",
        action="store_true",
        help=("Interpret the patch as a normal (i.e. JSON-serialized) " "difference."),
    )
    ns = parser.parse_args()
    # infer() may update ns (format flags, struc, diff, output) while it
    # works out which inputs to use.
    struc, diff, both = infer(ns)
    # closing() closes the output stream on exit, including sys.stdout when
    # that is the destination.
    with closing(ns.output) as out_f:
        # Encoded bytes are written below, so use the stream's underlying
        # binary buffer when it exposes one.
        out_f = getattr(out_f, "buffer", out_f)
        if ns.unified:
            output = json_delta.load_and_upatch(struc, diff, both, reverse=ns.reverse)
        else:
            output = json_delta.load_and_patch(struc, diff, both)
        output = json_delta._util.compact_json_dumps(output)
        out_f.write(output.encode(ns.encoding))
    return 0


# When run as a script, execute main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
