"""Parameter transforms.
This module contains the transforms that can be applied to a parameter.
A transform maps a parameter value θ from its constrained space to the
unconstrained space η, and vice versa. This keeps each parameter within its
allowed range while letting the optimization algorithm search freely for the
optimal parameters.
The transforms derive from the abstract base class ParameterTransform, which
declares the following abstract attributes / methods:
    - name: Name of the transform
    - transform: Transform a parameter value θ to the unconstrained space η
    - untransform: Transform a parameter value η back to the constrained space θ
    - grad_transform: Gradient of the transform function
    - grad_untransform: Gradient of the untransform function
    - penalty: Penalty for a parameter value θ
    - grad_penalty: Gradient of the penalty function
The transforms are registered in the Transforms namespace. The namespace is used
to get the transform class from the name of the transform.
The following transforms are available:
    - None: No transform
    - Fixed: Fixed parameter
    - Log: Log transform
    - Lower: Lower bound
    - Upper: Upper bound
    - Logit: Logit transform
An auto transform is also available. The auto transform will select the best
transform based on the bounds of the parameter.
"""
from abc import ABCMeta, abstractmethod
import numpy as np
from typing_extensions import Self
class Transforms:
    """Namespace registry mapping transform names to transform classes.

    Registered transforms are stored as class attributes, so they can be looked
    up either with ``Transforms.get(name)`` or as ``Transforms.<name>``.
    """

    @classmethod
    def register(cls, transform, name=None):
        """Register *transform* under *name*.

        Parameters
        ----------
        transform : type
            The transform class to register.
        name : str, optional
            Registry key. Defaults to the transform's ``name`` attribute when
            that attribute is a plain string, otherwise to the transform's
            ``__qualname__``.

        Raises
        ------
        ValueError
            If an attribute with that name already exists on the registry
            (this includes the registry's own ``register`` and ``get``).
        """
        if name is None:
            name = getattr(transform, "name", None)
            # A subclass implementing the abstract ``name`` *property* exposes a
            # ``property`` object at class level rather than a string; using it
            # as an attribute name would raise TypeError in hasattr/setattr.
            if not isinstance(name, str):
                name = transform.__qualname__
        if hasattr(cls, name):
            raise ValueError("Transform with that name already exists")
        setattr(cls, name, transform)

    @classmethod
    def get(cls, name):
        """Return the transform class registered under *name*.

        Note that the registry's own method names (``register``, ``get``)
        resolve to those methods rather than to a transform.

        Raises
        ------
        AttributeError
            If nothing with that name exists on the registry.
        """
        return getattr(cls, name)
def register_transform(cls):
    """Class decorator that adds *cls* to the ``Transforms`` registry.

    The class is registered under its default name (as chosen by
    ``Transforms.register``) and handed back unchanged.

    Raises
    ------
    ValueError
        Propagated from ``Transforms.register`` when the name is already taken.
    """
    Transforms.register(cls, name=None)
    return cls
class ParameterTransform(metaclass=ABCMeta):
    """Abstract base class for parameter transforms.

    A transform maps a parameter value θ in the constrained space to a value η
    in the unconstrained space (``transform``) and back (``untransform``). It
    also provides the gradients of both maps and a penalty term.

    Parameters
    ----------
    bounds : tuple
        Pair ``(lb, ub)``, stored as the ``lb`` and ``ub`` attributes.

    Notes
    -----
    Two transforms compare equal when their ``name`` matches; the bounds are
    not taken into account. The hash is derived from ``name`` as well, so
    equal transforms hash equally and can be used in sets / as dict keys.
    """

    def __init__(self, bounds: tuple):
        # Tuple unpacking raises ValueError unless exactly two values are given.
        self.lb, self.ub = bounds

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the transform (also used for equality and hashing)."""
        pass

    @abstractmethod
    def transform(self, θ: float) -> float:
        """Transform a parameter value θ to the unconstrained space η
        Parameters
        ----------
        θ : float
            Parameter value in the constrained space θ
        Returns
        -------
        η : float
            Parameter value in the unconstrained space η
        """
        pass

    @abstractmethod
    def untransform(self, η: float) -> float:
        """Transform a parameter value η to the constrained space θ
        Parameters
        ----------
        η : float
            Parameter value in the unconstrained space η
        Returns
        -------
        θ : float
            Parameter value in the constrained space θ
        """
        pass

    @abstractmethod
    def grad_transform(self, θ: float) -> float:
        """Gradient of the transform function
        Parameters
        ----------
        θ : float
            Parameter value in the constrained space θ
        Returns
        -------
        grad : float
            Gradient of the transform function
        """
        pass

    @abstractmethod
    def grad_untransform(self, η: float) -> float:
        """Gradient of the untransform function
        Parameters
        ----------
        η : float
            Parameter value in the unconstrained space η
        Returns
        -------
        grad : float
            Gradient of the untransform function
        """
        pass

    @abstractmethod
    def penalty(self, θ: float) -> float:
        """Penalty for the parameter value θ
        Parameters
        ----------
        θ : float
            Parameter value in the constrained space θ
        Returns
        -------
        penalty : float
            Penalty for the parameter value θ
        """
        pass

    @abstractmethod
    def grad_penalty(self, θ: float) -> float:
        """Gradient of the penalty function for the parameter value θ
        Parameters
        ----------
        θ : float
            Parameter value in the constrained space θ
        Returns
        -------
        grad_penalty : float
            Gradient of the penalty function for the parameter value θ
        """
        pass

    def in_bounds(self, x: float) -> bool:
        """Check if the parameter value is in the bounds of the transform

        This base implementation accepts every value; subclasses may override
        it to enforce their bounds.

        Parameters
        ----------
        x : float
            Parameter value in the constrained space θ
        Returns
        -------
        in_bounds : bool
            True if the parameter value is in the bounds of the transform
        """
        return True

    def __repr__(self):
        return f"{self.name}"

    def __eq__(self, other: object) -> bool:
        """Transforms are equal when they share the same ``name``."""
        if not isinstance(other, ParameterTransform):
            # Defer to the other operand / identity fallback per the data model.
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None (unhashable instances);
        # hash on the same key equality uses so equal transforms hash equally.
        return hash(self.name)