#  Copyright (c) 2022, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

"""
This file contains a pass for lowering complex dialect ops into core ops.

Steps for adding a new complex dialect op:
1. Add a dialect op in complex_dialect_ops.py
2. Add a corresponding lowering function

In Step 2, note that when implementing lowering functions, we need to specify `before_op` when
emitting core ops. This is for both correctness and the readability of the SSA graph, because
the generated core ops should be placed before the ops that come after that dialect op.
More specifically, here is the SSA graph before lowering:
    block0() {
      %1 = complex_dialect_op(data=%input)
      %2 = core_op1(x=%1)
      %3 = core_op2(x=%2)
    } -> (%3)
While lowering `complex_dialect_op`, we want all newly generated core ops to be placed before
`core_op1`.
"""

import functools
from typing import Callable, Dict, Optional, Tuple

import numpy as np

from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.operation import Operation
from coremltools.converters.mil.mil.ops.defs.complex_dialect_ops import (
    fft_canonicalize_length_dim,
    fft_canonicalize_shapes_dims,
)
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.var import ComplexVar, Var


class LowerComplex:
    """Registry that associates complex dialect op types with their lowering functions."""

    # Maps an op type name to the function registered to lower ops of that type.
    _lower_map: Dict[str, Callable] = {}

    @staticmethod
    def register_lower_func(op_type: str) -> Callable:
        """
        Build a decorator that registers the decorated function as the lowering function for
        `op_type`.

        The decorator raises ValueError if `op_type` already has a registered function. The
        undecorated function is what gets stored in the registry; the decorator itself returns
        a thin pass-through wrapper around it.
        """

        def decorator(lower_fn):
            @functools.wraps(lower_fn)
            def passthrough(*args, **kwargs):
                return lower_fn(*args, **kwargs)

            registry = LowerComplex._lower_map
            if op_type in registry:
                raise ValueError(
                    f"The op {op_type} already got lowering function registered."
                )
            registry[op_type] = lower_fn
            return passthrough

        return decorator

    @staticmethod
    def has_lower_func(op_type: str) -> bool:
        """Tell whether a lowering function has been registered for `op_type`."""
        return op_type in LowerComplex._lower_map

    @staticmethod
    def get_lower_func(op_type: str) -> Callable:
        """Return the lowering function registered for `op_type`, or raise ValueError."""
        try:
            return LowerComplex._lower_map[op_type]
        except KeyError:
            raise ValueError(
                f"The op {op_type} doesn't have any lowering function registered."
            ) from None


def _resize_data(
    input_data: Var, dims: Tuple[int], sizes: Tuple[int], before_op: Operation
) -> Var:
    """
    Trim or zero-pad `input_data` so that each axis in `dims` gets the paired length in `sizes`.

    The (dim, size) pairs are handled in order. An axis longer than its target keeps only its
    leading `size` entries (range_1d + gather); a shorter axis gets zeros appended at the end
    (fill + concat); an axis that is neither longer nor shorter is left untouched. All emitted
    ops are placed before `before_op`.
    """
    resized = input_data
    for axis, target_len in zip(dims, sizes):
        current_len = resized.shape[axis]
        if target_len < current_len:
            # Keep the first `target_len` entries along this axis.
            keep_indices = mb.range_1d(
                start=0, end=target_len, step=1, before_op=before_op
            )
            resized = mb.gather(
                x=resized, indices=keep_indices, axis=axis, before_op=before_op
            )
        elif target_len > current_len:
            # Append a block of zeros that fills the gap along this axis.
            pad_shape = list(resized.shape)
            pad_shape[axis] = target_len - current_len
            padding = mb.fill(shape=pad_shape, value=0.0, before_op=before_op)
            resized = mb.concat(
                values=[resized, padding], axis=axis, before_op=before_op
            )

    return resized


def _restore_conj(
    input_data: ComplexVar, n: Var, dim: Var, before_op: Operation
) -> Tuple[Var, Var]:
    """
    Rebuild a full spectrum from a one-sided Hermitian input (as produced by rfft()).

    The missing half follows X[i] = conj(X[-i]): the real part is mirrored as-is, while the
    mirrored imaginary part is negated. When `n` carries a known value, both parts are first
    resized along `dim` to `n // 2 + 1` entries; otherwise the signal length is taken to be
    `2 * (input_len - 1)`. For an odd signal length the last input element takes part in the
    mirroring as well; for an even length it does not.

    Returns the (real, imaginary) pair of the restored data.
    """
    axis = dim.val
    real_data: Var = input_data.real
    imag_data: Var = input_data.imag

    # Default signal length implied by the one-sided input; overridden by an explicit n.
    size = 2 * (input_data.real.shape[axis] - 1)
    n_is_known = n is not None and n.val is not None
    if n_is_known:
        size = n.val
        one_sided_len = size // 2 + 1
        real_data = _resize_data(
            real_data, dims=(axis,), sizes=(one_sided_len,), before_op=before_op
        )
        imag_data = _resize_data(
            imag_data, dims=(axis,), sizes=(one_sided_len,), before_op=before_op
        )

    # Index of the first element to mirror; the mirror walks from there down to index 1.
    current_len = real_data.shape[axis]
    if size % 2 == 0:
        range_end = current_len - 2
    else:
        range_end = current_len - 1

    if range_end > 0:
        mirror_indices = mb.range_1d(
            start=range_end, end=0, step=-1, before_op=before_op
        )
        mirrored_real = mb.gather(
            x=real_data, indices=mirror_indices, axis=axis, before_op=before_op
        )
        mirrored_imag = mb.gather(
            x=imag_data, indices=mirror_indices, axis=axis, before_op=before_op
        )
        # Conjugation flips the sign of the imaginary part only.
        mirrored_imag = mb.mul(x=mirrored_imag, y=-1.0, before_op=before_op)

        real_data = mb.concat(
            values=[real_data, mirrored_real], axis=axis, before_op=before_op
        )
        imag_data = mb.concat(
            values=[imag_data, mirrored_imag], axis=axis, before_op=before_op
        )

    return real_data, imag_data


def _fft_1d(
    input_real: Var,
    input_imag: Var,
    n: Optional[Var],
    dim: Optional[Var],
    norm: Optional[Var],
    before_op: Operation,
    inverse: bool = False,  # For inverse FFT.
) -> Tuple[Var, Var]:
    """
    1-D FFT (or inverse FFT when `inverse` is True) by DFT matrix multiplication.

    The input is a complex signal given as separate real and imaginary tensors. `n` and `dim`
    are first canonicalized by `fft_canonicalize_length_dim`; the results are used below as a
    plain length and a plain axis index. `norm.val` is read unconditionally, so `norm` must not
    be None even though it is annotated as Optional. All emitted ops are placed before
    `before_op`. Returns the (real, imaginary) pair of the transformed data.

    The core issue is how to derive the DFT matrix. The DFT matrix consists of different powers
    of `w`, where w=e^(-2*pi*i/N) for the forward transform and w=e^(2*pi*i/N) for the inverse
    transform, so we need to separate the real and imaginary parts of those powers. To achieve
    that, we construct the following matrix of exponents (the power of `w` at each position):
        0    0    0      ...    0
        0    1    2      ...    N-1
        0    2    4      ...    2(N-1)
        ...    ....      ...
        0   N-1  2(N-1)  ...    (N-1)(N-1)
    This matrix is the outer product of two range tensors; scaling it by 2*pi/N gives `base`.

    Taking cos and sin of `base` gives the `cos_base` and `sin_base` matrices. Then, based on:
        * The addition of complex numbers: (a+bi)+(c+di)=(a+c)+(b+d)i.
        * The multiplication of complex numbers: (a+bi)(c+di)=(ac-bd)+(ad+bc)i.
        * Euler's formula: e^(xi)=cos(x)+i*sin(x).
        * Cosine is an even function: cos(-x)=cos(x).
        * Sine is an odd function: sin(-x)=-sin(x).
    the forward transform is:
        * real part:      cos_base * input_real + sin_base * input_imag
        * imaginary part: -(sin_base * input_real - cos_base * input_imag)
    and the inverse transform is:
        * real part:      cos_base * input_real - sin_base * input_imag
        * imaginary part: sin_base * input_real + cos_base * input_imag

    Normalization: the result is divided by N when (`norm` is "forward" and not inverse) or
    (`norm` is "backward" and inverse), and by sqrt(N) when `norm` is "ortho"; otherwise it is
    left unscaled.
    """
    n, dim = fft_canonicalize_length_dim(input_real, n, dim)

    # Swaps target dim axis to the first axis.
    # The permutation only exchanges two axes, so applying it again later restores the layout.
    axes = list(range(len(input_real.shape)))
    axes[0] = dim
    axes[dim] = 0
    transposed_input_real = mb.transpose(x=input_real, perm=axes, before_op=before_op)
    transposed_input_imag = mb.transpose(x=input_imag, perm=axes, before_op=before_op)

    # Trim or pad input according to n.
    transposed_input_real = _resize_data(
        input_data=transposed_input_real,
        dims=(0,),
        sizes=(n,),
        before_op=before_op,
    )
    transposed_input_imag = _resize_data(
        input_data=transposed_input_imag,
        dims=(0,),
        sizes=(n,),
        before_op=before_op,
    )

    # Calculate DFT matrix.
    # Flatten everything but the transform axis so the DFT becomes a single 2-D matmul.
    original_shape = transposed_input_real.shape
    N = transposed_input_real.shape[0]
    reshaped_input_real = mb.reshape(
        x=transposed_input_real, shape=[N, -1], before_op=before_op
    )
    reshaped_input_imag = mb.reshape(
        x=transposed_input_imag, shape=[N, -1], before_op=before_op
    )
    tmp = mb.range_1d(start=0, end=N, step=1, before_op=before_op)
    # Use MIL ops to calculate base = torch.outer(tmp, tmp) * (2 * torch.pi / N).
    tmp_x = mb.reshape(x=tmp, shape=[-1, 1], before_op=before_op)
    tmp_y = mb.reshape(x=tmp, shape=[1, -1], before_op=before_op)
    base = mb.matmul(x=tmp_x, y=tmp_y, before_op=before_op)
    base = mb.cast(x=base, dtype="fp32", before_op=before_op)
    base = mb.mul(x=base, y=2 * np.pi, before_op=before_op)
    # From here on `N` is the fp32 result of the cast (no longer the raw shape entry); the
    # normalization step below reuses it as the divisor.
    N = mb.cast(x=N, dtype="fp32", before_op=before_op)
    base = mb.real_div(x=base, y=N, before_op=before_op)
    # Get real part and imaginary part separately.
    cos_base = mb.cos(x=base, before_op=before_op)
    sin_base = mb.sin(x=base, before_op=before_op)

    if not inverse:
        # Forward transform: multiply by cos(base) - i*sin(base).
        real_part = mb.add(
            x=mb.matmul(x=cos_base, y=reshaped_input_real, before_op=before_op),
            y=mb.matmul(x=sin_base, y=reshaped_input_imag, before_op=before_op),
            before_op=before_op,
        )
        imag_part = mb.sub(
            x=mb.matmul(x=sin_base, y=reshaped_input_real, before_op=before_op),
            y=mb.matmul(x=cos_base, y=reshaped_input_imag, before_op=before_op),
            before_op=before_op,
        )
        imag_part = mb.mul(x=imag_part, y=-1.0, before_op=before_op)
    else:
        # Inverse transform: multiply by cos(base) + i*sin(base).
        real_part = mb.sub(
            x=mb.matmul(x=cos_base, y=reshaped_input_real, before_op=before_op),
            y=mb.matmul(x=sin_base, y=reshaped_input_imag, before_op=before_op),
            before_op=before_op,
        )
        imag_part = mb.add(
            x=mb.matmul(x=sin_base, y=reshaped_input_real, before_op=before_op),
            y=mb.matmul(x=cos_base, y=reshaped_input_imag, before_op=before_op),
            before_op=before_op,
        )

    # Undo the flattening done before the matmul.
    real_part = mb.reshape(x=real_part, shape=original_shape, before_op=before_op)
    imag_part = mb.reshape(x=imag_part, shape=original_shape, before_op=before_op)

    # Swaps dim back.
    real_part = mb.transpose(x=real_part, perm=axes, before_op=before_op)
    imag_part = mb.transpose(x=imag_part, perm=axes, before_op=before_op)

    # Normalization if needed.
    apply_scale = False
    scale = 1
    if norm.val is not None:
        # For FFT, "forward" means normalize 1/N, while in IFFT, "backward" means normalize 1/N.
        # "ortho" divides by sqrt(N) in both directions.
        if (not inverse) and (norm.val in ["forward", "ortho"]):
            apply_scale = True
            scale = N if norm.val == "forward" else mb.sqrt(x=N, before_op=before_op)
        if inverse and (norm.val in ["backward", "ortho"]):
            apply_scale = True
            scale = N if norm.val == "backward" else mb.sqrt(x=N, before_op=before_op)
    if apply_scale:
        real_part = mb.real_div(x=real_part, y=scale, before_op=before_op)
        imag_part = mb.real_div(x=imag_part, y=scale, before_op=before_op)

    return real_part, imag_part


def _rfft_1d(
    input_real: Var,
    n: Optional[Var],
    dim: Optional[Var],
    norm: Optional[Var],
    before_op: Operation,
) -> Tuple[Var, Var]:
    """
    1-D FFT of real-valued input that keeps only the non-redundant half of the spectrum.

    A zero-filled imaginary part with the input's shape is paired with `input_real` and run
    through `_fft_1d`; afterwards only the first `len // 2 + 1` entries along `dim` are kept,
    since the rest is the conjugate part. `dim.val` is read, so `dim` must not be None.
    """
    input_shape = mb.shape(x=input_real, before_op=before_op)
    zero_imag = mb.fill(shape=input_shape, value=0.0, before_op=before_op)
    full_real, full_imag = _fft_1d(
        input_real, zero_imag, n, dim, norm, before_op=before_op
    )

    # Drop the conjugate-redundant tail along the transform axis.
    axis = dim.val
    kept_len = full_real.shape[axis] // 2 + 1
    kept_indices = mb.range_1d(start=0, end=kept_len, step=1, before_op=before_op)
    half_real = mb.gather(
        x=full_real, indices=kept_indices, axis=axis, before_op=before_op
    )
    half_imag = mb.gather(
        x=full_imag, indices=kept_indices, axis=axis, before_op=before_op
    )

    return half_real, half_imag


def _wrap_complex_output(
    original_output: Var, real_data: Var, imag_data: Var
) -> ComplexVar:
    """
    Bundle `real_data` and `imag_data` into a ComplexVar that stands in for `original_output`.

    The new var reuses the original output's sym_type and takes its name with a "_lowered"
    suffix appended.
    """
    lowered_name = original_output.name + "_lowered"
    lowered_type = original_output.sym_type
    return ComplexVar(
        name=lowered_name, sym_type=lowered_type, real=real_data, imag=imag_data
    )


@LowerComplex.register_lower_func(op_type="complex")
def _lower_complex(op: Operation):
    """Lower `complex` by wrapping its real and imaginary inputs into a ComplexVar."""
    original_output = op.outputs[0]
    real_part = op.real_data
    imag_part = op.imag_data
    return _wrap_complex_output(original_output, real_part, imag_part)


@LowerComplex.register_lower_func(op_type="complex_real")
def _lower_complex_real(op: Operation):
    """Lower `complex_real` to an identity op on the real part of the complex input."""
    source: ComplexVar = op.data
    # The identity op is deliberate: returning source.real directly could make the var's name
    # inconsistent with the block's input name.
    return mb.identity(x=source.real, before_op=op)


@LowerComplex.register_lower_func(op_type="complex_imag")
def _lower_complex_imag(op: Operation):
    """Lower `complex_imag` to an identity op on the imaginary part of the complex input."""
    source: ComplexVar = op.data
    # The identity op is deliberate: returning source.imag directly could make the var's name
    # inconsistent with the block's input name.
    return mb.identity(x=source.imag, before_op=op)


@LowerComplex.register_lower_func(op_type="complex_fft")
def _lower_complex_fft(op: Operation):
    """
    Lower `complex_fft` to a 1-D FFT over real/imaginary parts.

    Complex input is split into its two parts. Non-complex input is used as the real part and
    paired with a zero-filled imaginary part of the same shape, whose fill value is cast to the
    input's dtype.
    """
    data = op.data
    if types.is_complex(data.dtype):
        real_data, imag_data = data.real, data.imag
    else:
        real_data = data
        data_shape = mb.shape(x=real_data, before_op=op)
        zero = mb.const(val=0.0, before_op=op)
        typed_zero = mb.cast(x=zero, dtype=real_data.dtype.__name__, before_op=op)
        imag_data = mb.fill(shape=data_shape, value=typed_zero, before_op=op)

    fft_real, fft_imag = _fft_1d(
        real_data, imag_data, op.n, op.dim, op.norm, before_op=op
    )
    return _wrap_complex_output(op.outputs[0], fft_real, fft_imag)


@LowerComplex.register_lower_func(op_type="complex_fftn")
def _lower_complex_fftn(op: Operation):
    """
    Lower `complex_fftn` to a chain of 1-D FFTs, one per (shape, dim) pair.

    Complex input is split into its two parts. Non-complex input is used as the real part and
    paired with a zero-filled imaginary part of the same shape, whose fill value is cast to the
    input's dtype. The pairs come from `fft_canonicalize_shapes_dims`.
    """
    data = op.data
    if types.is_complex(data.dtype):
        real_data, imag_data = data.real, data.imag
    else:
        real_data = data
        data_shape = mb.shape(x=real_data, before_op=op)
        zero = mb.const(val=0.0, before_op=op)
        typed_zero = mb.cast(x=zero, dtype=real_data.dtype.__name__, before_op=op)
        imag_data = mb.fill(shape=data_shape, value=typed_zero, before_op=op)

    shapes, dims = fft_canonicalize_shapes_dims(real_data, op.shapes, op.dims)
    for length, axis in zip(shapes, dims):
        length_var = mb.const(val=length, before_op=op)
        axis_var = mb.const(val=axis, before_op=op)
        real_data, imag_data = _fft_1d(
            real_data,
            imag_data,
            n=length_var,
            dim=axis_var,
            norm=op.norm,
            before_op=op,
        )

    return _wrap_complex_output(op.outputs[0], real_data, imag_data)


@LowerComplex.register_lower_func(op_type="complex_rfft")
def _lower_complex_rfft(op: Operation):
    """Lower `complex_rfft` to a one-sided 1-D FFT of the real input."""
    half_real, half_imag = _rfft_1d(
        op.data, op.n, op.dim, op.norm, before_op=op
    )
    original_output = op.outputs[0]
    return _wrap_complex_output(original_output, half_real, half_imag)


@LowerComplex.register_lower_func(op_type="complex_rfftn")
def _lower_complex_rfftn(op: Operation):
    """
    Lower `complex_rfftn`: a one-sided 1-D FFT along the last canonicalized dim, followed by
    full 1-D FFTs along each of the remaining dims.
    """
    shapes, dims = fft_canonicalize_shapes_dims(op.data, op.shapes, op.dims)

    # The real-input transform runs first, on the last (shape, dim) pair.
    last_length_var = mb.const(val=shapes[-1], before_op=op)
    last_axis_var = mb.const(val=dims[-1], before_op=op)
    real_data, imag_data = _rfft_1d(
        op.data, last_length_var, last_axis_var, op.norm, before_op=op
    )

    # The remaining pairs are ordinary complex FFTs.
    for length, axis in zip(shapes[:-1], dims[:-1]):
        length_var = mb.const(val=length, before_op=op)
        axis_var = mb.const(val=axis, before_op=op)
        real_data, imag_data = _fft_1d(
            real_data,
            imag_data,
            n=length_var,
            dim=axis_var,
            norm=op.norm,
            before_op=op,
        )

    return _wrap_complex_output(op.outputs[0], real_data, imag_data)


@LowerComplex.register_lower_func(op_type="complex_ifft")
def _lower_complex_ifft(op: Operation):
    """Lower `complex_ifft` to an inverse 1-D FFT over the input's real/imaginary parts."""
    complex_input = op.data
    ifft_real, ifft_imag = _fft_1d(
        complex_input.real,
        complex_input.imag,
        op.n,
        op.dim,
        op.norm,
        before_op=op,
        inverse=True,
    )
    return _wrap_complex_output(op.outputs[0], ifft_real, ifft_imag)


@LowerComplex.register_lower_func(op_type="complex_ifftn")
def _lower_complex_ifftn(op: Operation):
    """Lower `complex_ifftn` to a chain of inverse 1-D FFTs, one per (shape, dim) pair."""
    complex_input = op.data
    real_data, imag_data = complex_input.real, complex_input.imag
    shapes, dims = fft_canonicalize_shapes_dims(real_data, op.shapes, op.dims)

    for length, axis in zip(shapes, dims):
        length_var = mb.const(val=length, before_op=op)
        axis_var = mb.const(val=axis, before_op=op)
        real_data, imag_data = _fft_1d(
            real_data,
            imag_data,
            n=length_var,
            dim=axis_var,
            norm=op.norm,
            before_op=op,
            inverse=True,
        )

    return _wrap_complex_output(op.outputs[0], real_data, imag_data)


@LowerComplex.register_lower_func(op_type="complex_irfft")
def _lower_complex_irfft(op: Operation):
    """
    Lower `complex_irfft`: restore the full spectrum from the one-sided input, run an inverse
    1-D FFT on it, and return only the real part of the result.
    """
    full_real, full_imag = _restore_conj(op.data, op.n, op.dim, before_op=op)
    n, dim = fft_canonicalize_length_dim(op.data, op.n, op.dim, c2r=True)

    length_var = mb.const(val=n, before_op=op)
    axis_var = mb.const(val=dim, before_op=op)
    result_real, _ = _fft_1d(
        full_real,
        full_imag,
        length_var,
        axis_var,
        op.norm,
        before_op=op,
        inverse=True,
    )
    # The output of irfft is real-valued, so the imaginary part is discarded.
    return result_real


@LowerComplex.register_lower_func(op_type="complex_irfftn")
def _lower_complex_irfftn(op: Operation):
    """
    Lower `complex_irfftn`: inverse 1-D FFTs along all but the last canonicalized dim, then a
    1-D inverse real FFT along the last one. Only the real part of the result is returned.
    """
    complex_input = op.data
    real_data, imag_data = complex_input.real, complex_input.imag
    shapes, dims = fft_canonicalize_shapes_dims(real_data, op.shapes, op.dims, c2r=True)

    # All but the last (shape, dim) pair: ordinary inverse FFT.
    for length, axis in zip(shapes[:-1], dims[:-1]):
        length_var = mb.const(val=length, before_op=op)
        axis_var = mb.const(val=axis, before_op=op)
        real_data, imag_data = _fft_1d(
            real_data,
            imag_data,
            n=length_var,
            dim=axis_var,
            norm=op.norm,
            before_op=op,
            inverse=True,
        )

    # Last pair: restore the conjugate half, inverse-transform, then resize to the target size.
    last_length_var = mb.const(val=shapes[-1], before_op=op)
    last_axis_var = mb.const(val=dims[-1], before_op=op)
    one_sided = _wrap_complex_output(op.outputs[0], real_data, imag_data)
    real_data, imag_data = _restore_conj(
        input_data=one_sided, n=last_length_var, dim=last_axis_var, before_op=op
    )
    real_data, imag_data = _fft_1d(
        real_data,
        imag_data,
        last_length_var,
        last_axis_var,
        op.norm,
        before_op=op,
        inverse=True,
    )
    real_data = _resize_data(
        real_data, dims=(last_axis_var.val,), sizes=(last_length_var.val,), before_op=op
    )

    return real_data


@LowerComplex.register_lower_func(op_type="complex_shape")
def _lower_complex_shape(op: Operation):
    """Lower `complex_shape` to a shape op applied to the real part of the complex input."""
    real_part = op.data.real
    return mb.shape(x=real_part, before_op=op)


def _match_and_replace_dialect_op(block, op):
    """
    Lower `op` if a lowering function is registered for its op type.

    Returns False when nothing is registered for the op type. Otherwise the lowering function
    is called, later uses of the op's first output are redirected to the lowered result, the op
    is removed from `block`, and True is returned. Raises ValueError if the uses cannot be
    replaced.
    """
    if not LowerComplex.has_lower_func(op.op_type):
        return False

    lowering_fn = LowerComplex.get_lower_func(op.op_type)
    lowered_var = lowering_fn(op)

    replaced = op.enclosing_block.try_replace_uses_of_var_after_op(
        anchor_op=op,
        old_var=op.outputs[0],
        new_var=lowered_var,
    )
    if not replaced:
        raise ValueError(f"Unable to lower complex dialect op {op}")

    block.remove_ops([op])
    return True


@block_context_manager
def _lower_complex_dialect_ops_in_block(block):
    """
    Repeatedly lower complex dialect ops in `block` until a full scan lowers nothing.

    Each scan stops at the first op that gets lowered and a new scan starts over from a fresh
    snapshot of the block's operations.
    """

    def lower_first_match(target_block):
        # Scan a snapshot of the op list; any() stops at the first successful lowering.
        return any(
            _match_and_replace_dialect_op(target_block, candidate)
            for candidate in list(target_block.operations)
        )

    while lower_first_match(block):
        pass


@register_pass(namespace="common")
class lower_complex_dialect_ops(AbstractGraphPass):
    """
    Lower complex dialect ops into core ops that work on the real and imaginary parts
    separately.

    Before a function is lowered, its outputs are checked: a complex-typed output raises a
    ValueError right away, since complex data is not supported as a model output.

    Input graph (`complex` and `complex_real` are complex dialect ops):
        %c = complex(real_data=%re, imag_data=%im)
        %out = complex_real(data=%c)
        return %out

    Output graph (illustrative; only core ops remain, no complex dialect ops):
        %out = identity(x=%re)
        return %out
    """

    def apply(self, prog):
        for block in prog.functions.values():
            # Fail early if any output of this function is complex.
            has_complex_output = any(
                types.is_complex(out_var.dtype) for out_var in block.outputs
            )
            if has_complex_output:
                raise ValueError(
                    "MIL doesn't support complex data as model's output, please "
                    "extract real and imaginary parts explicitly."
                )

            _lower_complex_dialect_ops_in_block(block)
