import numpy as np


class MicroAccel:
    """Stateful car-following controller producing an acceleration command.

    Every call to :meth:`get_main_accel` computes three candidate
    accelerations and returns the smallest one, clipped to
    ``[-|max_decel|, max_accel]`` (and to ``<= 0`` at or above ``v_limit``):

    * ``a_vdes`` -- proportional tracking of a desired speed ``v_des`` plus
      its finite-difference derivative, floored at ``-|max_decel_des|``;
    * ``a_vmax`` -- proportional tracking of a safe speed ``v_max`` plus its
      derivative, where ``v_max`` is the largest speed from which braking at
      ``|max_decel|`` fits in the free gap ``headway - s0`` plus the leader's
      braking distance at ``|max_decel_lead|``;
    * ``a_mng`` -- a reaction to the averaged leader acceleration and the
      closing speed, floored at ``-|max_decel_comfort|``.

    The instance keeps histories of the leader speed/acceleration and the
    previous values of several signals, so it is meant to be called once per
    simulation step of length ``sim_step`` for a single controlled vehicle.
    Units are presumably SI (m, m/s, m/s^2, s) -- not enforced here.
    """

    def __init__(
        self,
        max_decel_lead=-4.5,
        max_accel=1,
        max_decel_comfort=-1,
        max_decel_des=-0.4,
        max_decel=-3,
        v_limit=30,
        s0=18,
        k=1,
        in_head=2,
        long_time_head=5,
        gap_pen=1,
        gap_target=3,
        period=128,
        time_avg_acc=1,
        sim_step=0.05,
    ):
        """Store the controller parameters and reset the internal state.

        Args:
            max_decel_lead: Assumed leader braking rate used in ``v_max``
                (only its magnitude is used).
            max_accel: Upper clip of the returned acceleration; also the
                value of ``a_mng`` on a free road and the cap on ``a_mng``
                when the leader is faster.
            max_decel_comfort: ``a_mng`` is floored at
                ``-|max_decel_comfort|``.
            max_decel_des: ``a_vdes`` is floored at ``-|max_decel_des|``.
            max_decel: Lower clip ``-|max_decel|`` of the returned
                acceleration, and the own braking rate used in ``v_max``.
            v_limit: Speed limit: fallback target speed, ``v_max`` on a free
                road, and no positive acceleration at or above it.
            s0: Standstill distance subtracted from the headway.
            k: Proportional gain of the speed-tracking terms.
            in_head: Gap growth per unit time used to extrapolate the headway
                while the leader signal is lost.
            long_time_head: Beyond a gap of ``(long_time_head - 1) * this_vel``
                a positive gap term is added to ``a_mng``; must be ``> 1`` and
                ``>= gap_target``.
            gap_pen: Gain of that gap term.
            gap_target: Time gap used in ``h1 = gap_target * this_vel`` when
                deriving the target speed.
            period: Length of the leader-speed history in the time unit of
                ``sim_step`` (``floor(period / sim_step)`` samples).
            time_avg_acc: Length of the leader-acceleration averaging window
                (``floor(time_avg_acc / sim_step)`` samples).
            sim_step: Time between successive calls; used for the
                finite-difference derivatives.
        """
        self.max_accel = max_accel
        self.max_decel_des = max_decel_des
        self.max_decel_comfort = max_decel_comfort
        self.max_decel = max_decel
        self.max_decel_lead = max_decel_lead
        self.v_limit = v_limit
        self.long_time_head = long_time_head
        self.gap_pen = gap_pen
        self.gap_target = gap_target
        self.s0 = s0
        self.k = k
        self.in_head = in_head
        self.period = period
        self.sim_step = sim_step
        # Last headway used for control; None until a leader has been seen.
        self.previous_dx = None
        # Previous desired / safe speed, for the finite-difference derivatives.
        self.previous_v_des = None
        self.previous_v_max = None
        # Last (clamped) leader speed, reused while the signal is lost.
        self.previous_lead_vel_real = None
        self.previous_lead_acc = 0
        # Rolling histories; 0 entries mean "no signal" and are ignored by
        # the averages (which divide by the count of non-zero entries).
        self.prev_vels = []
        self.prev_lead_acc = []
        self.time_avg_acc = time_avg_acc
        self.length_acc_avg = np.floor(self.time_avg_acc / self.sim_step)
        self.previous_headway = None
        self.average_step = np.floor(self.period / self.sim_step)
        # 1 while no leader has ever been detected, 0 afterwards.
        self.no_initial_signal = None

    @staticmethod
    def _push_bounded(buffer, value, max_len):
        """Append ``value`` to ``buffer``, dropping the oldest entries first.

        The buffer is kept at ``max_len`` entries; a window shorter than one
        sample degrades to keeping only the most recent value instead of
        popping from an empty list.
        """
        while buffer and len(buffer) >= max_len:
            buffer.pop(0)
        buffer.append(value)

    def get_main_accel(
        self, this_vel, lead_vel, headway, target_speed=-1, minicar=True,
    ):
        """Return the acceleration command for the current simulation step.

        Args:
            this_vel: Speed of the controlled vehicle.
            lead_vel: Measured leader speed. The raw value is recorded in
                ``prev_vels``; for control it is clamped to at least 0.1.
            headway: Measured gap to the leader. Values ``>= 252`` are treated
                as "no leader detected" (presumably the sensor's out-of-range
                reading -- confirm). If a leader was seen before and the last
                gap was below 220, the gap is extrapolated; otherwise the
                controller drives as on a free road.
            target_speed: Desired speed. ``-1`` (default) derives it from the
                average of the recorded non-zero leader speeds, raised when
                the headway exceeds ``max(gap_target * this_vel, 30)`` and
                capped at 35. Any other value is clamped to
                ``[0.8 * min(lead_vel, 35), 1.2 * lead_vel]`` once a leader
                has been detected.
            minicar: Unused; kept for interface compatibility.

        Returns:
            The acceleration command.

        Raises:
            ValueError: If the averaged leader acceleration is NaN.
        """
        h1 = self.gap_target * this_vel

        # Record the raw leader speed; entries zeroed below mark steps
        # without a leader signal.
        self._push_bounded(self.prev_vels, lead_vel, self.average_step)

        lead_vel = max(lead_vel, 0.1)

        # Guard against division by a (near) zero headway.
        if np.abs(headway) < 1e-3:
            headway = 1e-3

        # Handle loss of the leader signal.
        if headway >= 252 and self.previous_dx is None:
            # No leader seen yet: record nothing for this step.
            self.no_initial_signal = 1
            self.prev_vels[-1] = 0
            assert np.count_nonzero(self.prev_vels) == 0
        elif headway < 252:
            self.no_initial_signal = 0
        elif headway >= 252:
            # Signal lost after tracking a leader: reuse the last leader speed
            # and, if the leader was close, extrapolate the previous gap.
            self.prev_vels[-1] = 0
            lead_vel = self.previous_lead_vel_real
            if self.previous_dx < 220:
                headway = self.previous_dx + self.in_head * self.sim_step

        self.previous_lead_vel_real = lead_vel

        # Desired speed.
        ref_vel = 35
        if target_speed == -1:
            n_valid = np.count_nonzero(self.prev_vels)
            if n_valid == 0:
                target_speed = self.v_limit
            else:
                target_speed = sum(self.prev_vels) / n_valid
            assert target_speed > 0
            # Penalise too large a headway by raising the target speed
            # quadratically in the excess gap per unit speed; the constants
            # can be modified.
            excess_gap = max(0, (headway - max(h1, 30)) / max(1, this_vel))
            target_speed = min(0.1 * excess_gap ** 2 + target_speed, ref_vel)
        elif self.no_initial_signal == 0:
            target_speed = max(
                min(target_speed, 1.2 * lead_vel), 0.8 * min(lead_vel, ref_vel)
            )

        # Leader acceleration by finite difference, skipping steps without a
        # signal. An increase of 3.3 * sim_step or more is treated as an
        # outlier and the previous estimate is reused.
        # NOTE(review): decreases of any size are accepted -- confirm intended.
        # TODO can do a better smoothing
        if (
            len(self.prev_vels) <= 1
            or self.prev_vels[-1] == 0
            or self.prev_vels[-2] == 0
        ):
            lead_acc = 0
        elif (lead_vel - self.prev_vels[-2]) < 3.3 * self.sim_step:
            lead_acc = (lead_vel - self.prev_vels[-2]) / self.sim_step
        else:
            assert self.previous_lead_acc is not None
            lead_acc = self.previous_lead_acc
        self.previous_lead_acc = lead_acc

        # TODO add the patch for spotting the lack of signal.
        # TODO maybe issues when there is a lead veh then no lead veh then
        # again a lead veh for the memory quantity

        # v_des from the motion planner and its derivative.
        v_des = target_speed
        assert v_des != -1 and v_des is not None
        if self.previous_v_des is None:
            v_des_dot = 0
        else:
            v_des_dot = (v_des - self.previous_v_des) / self.sim_step
        self.previous_v_des = v_des

        # Free road: no leader yet, or the leader was lost while far away.
        # previous_dx is checked for None so it is never compared to a number
        # before the first leader detection.
        leader_lost_far = (
            headway >= 252
            and self.previous_dx is not None
            and self.previous_dx >= 220
        )
        if lead_vel is None or self.no_initial_signal == 1 or leader_lost_far:
            a_mng = self.max_accel
            v_max_dot = 0
            v_max = self.v_limit
            if lead_vel is None or self.no_initial_signal == 1:
                v_des = self.v_limit
                v_des_dot = 0
                self.previous_v_des = self.v_limit
        else:
            dx = headway
            if self.previous_dx is None:  # Initial step
                self.previous_dx = dx

            # Safe speed: braking at |max_decel| must fit in the free gap
            # plus the leader's braking distance at |max_decel_lead|.
            v_max = np.sqrt(
                2
                * np.abs(self.max_decel)
                * (
                    max(dx - self.s0, 0)
                    + 0.5 * lead_vel ** 2 / np.abs(self.max_decel_lead)
                )
            )
            if self.previous_v_max is None:  # Initial step
                v_max_dot = 0
            elif np.abs(dx - self.previous_dx) > self.sim_step * 50:
                # Exclude discontinuities in the headway.
                v_max_dot = 0
            else:
                v_max_dot = (v_max - self.previous_v_max) / self.sim_step
            self.previous_v_max = v_max

            # Average of the non-zero recent leader accelerations.
            self._push_bounded(
                self.prev_lead_acc, lead_acc, self.length_acc_avg
            )
            n_acc = np.count_nonzero(self.prev_lead_acc)
            lead_acc_avg = 0 if n_acc == 0 else sum(self.prev_lead_acc) / n_acc

            # Acceleration management.
            if lead_acc_avg < 0:
                # Leader braking.
                a_0 = lead_acc_avg * this_vel / (lead_vel + 0.001)
                a_12 = (
                    -0.5
                    * this_vel ** 2
                    / (
                        max(0, dx - self.s0)
                        + 0.5 * (lead_vel + 0.00001) ** 2 / abs(lead_acc_avg - 0.01)
                    )
                )
                if a_0 < a_12:
                    a_mng = a_12
                else:
                    if lead_vel >= this_vel:
                        a_mng = a_0
                    else:
                        a_mng = lead_acc_avg - (lead_vel - this_vel) ** 2 / (
                            2 * max(dx - self.s0, 0.0001)
                        )
                    assert a_mng <= 0, (a_mng, a_0)
            elif lead_acc_avg >= 0:
                if lead_vel <= this_vel:
                    # Closing in: shed the closing speed over the free gap.
                    a_mng = lead_acc_avg - max(0, this_vel - lead_vel) ** 2 / (
                        2 * max(dx - self.s0, 0.001)
                    )
                else:
                    # TODO this can be changed / might be the source of some
                    # oscillations, together with long_time_head.
                    a_mng = min(
                        self.max_accel,
                        lead_acc_avg + lead_acc_avg * (lead_vel - this_vel),
                    )
            else:
                # Only reachable when lead_acc_avg is NaN.
                raise ValueError(
                    f"Averaged leader acceleration is NaN: {lead_acc_avg!r}"
                )

            # Relax a_mng with a positive gap term when the gap is large.
            assert self.long_time_head > 1 and self.long_time_head >= self.gap_target
            if dx > this_vel * (self.long_time_head - 1):
                a_mng = a_mng + self.gap_pen * (
                    dx - (self.long_time_head - 1) * this_vel
                ) / max(this_vel, 0.01)

            self.previous_dx = dx

        # Final acceleration: most restrictive candidate, then hard limits.
        a_mng = max(a_mng, -abs(self.max_decel_comfort))
        a_vdes = max(
            -self.k * (this_vel - v_des) + v_des_dot, -abs(self.max_decel_des)
        )
        a_vmax = -self.k * (this_vel - v_max) + v_max_dot
        accel = min(a_vdes, a_vmax, a_mng)
        accel = min(max(accel, -abs(self.max_decel)), self.max_accel)

        if this_vel >= self.v_limit:
            accel = min(0, accel)

        return accel
