Module pycpd.constrained_deformable_registration

Expand source code
from builtins import super
import numpy as np
import numbers
from .deformable_registration import DeformableRegistration


class ConstrainedDeformableRegistration(DeformableRegistration):
    """
    Constrained deformable registration.

    Deformable registration extended with correspondence priors. Every pair
    ``(source_id[k], target_id[k])`` adds a term, weighted by
    ``sigma2 / e_alpha``, that pulls row ``source_id[k]`` of the source point
    set ``Y`` towards row ``target_id[k]`` of the target point set ``X``.

    Attributes
    ----------
    alpha: float (positive)
        Represents the trade-off between the goodness of maximum likelihood fit and regularization.

    beta: float (positive)
        Width of the Gaussian kernel.

    e_alpha: float (positive)
        Reliability of correspondence priors. Between 1e-8 (very reliable) and 1 (very unreliable).
        Defaults to 1e-8 when not given.

    source_id: numpy.ndarray (int)
        Indices for the points to be used as correspondences in the source array.
        ``source_id[k]`` is paired with ``target_id[k]``.

    target_id: numpy.ndarray (int)
        Indices for the points to be used as correspondences in the target array.
        Must have the same length as ``source_id``.

    P_tilde: numpy.ndarray
        (M, N) indicator matrix with a 1 at every prior correspondence.

    P1_tilde: numpy.ndarray
        Row sums of ``P_tilde``: the number of priors attached to each source point.

    PX_tilde: numpy.ndarray
        ``P_tilde @ X``: for each source point, the sum of its prior target points.

    Raises
    ------
    ValueError
        If ``e_alpha`` is not a positive real number, if ``source_id`` or
        ``target_id`` is not a 1D integer numpy array, or if their lengths differ.
    """

    def __init__(self, e_alpha=None, source_id=None, target_id=None, *args, **kwargs):
        # Check the constraint-specific arguments first, so bad input is
        # rejected before the parent constructor does its set-up work.
        # ``not e_alpha > 0`` (rather than ``e_alpha <= 0``) also rejects NaN.
        if e_alpha is not None and (not isinstance(e_alpha, numbers.Real) or not e_alpha > 0):
            raise ValueError(
                "Expected a positive value for regularization parameter e_alpha. Instead got: {}".format(e_alpha))

        for name, ids in (("source", source_id), ("target", target_id)):
            # The fancy indexing below needs integer indices. Float arrays would
            # fail there with an unrelated error, and bool arrays would be
            # interpreted as masks.
            if (not isinstance(ids, np.ndarray) or ids.ndim != 1
                    or not np.issubdtype(ids.dtype, np.integer)):
                raise ValueError(
                    "The {0} ids ({0}_id) must be a 1D numpy array of ints.".format(name))

        if source_id.shape != target_id.shape:
            # Without this check, a length-1 array would silently broadcast
            # against the other one instead of pairing element-wise.
            raise ValueError(
                "source_id and target_id must have the same length. "
                "Instead got {} and {}.".format(source_id.size, target_id.size))

        super().__init__(*args, **kwargs)

        self.e_alpha = 1e-8 if e_alpha is None else e_alpha
        self.source_id = source_id
        self.target_id = target_id
        # Indicator matrix of the prior correspondences, shape (M, N), with a 1
        # at (source_id[k], target_id[k]). Duplicate pairs still give a single 1.
        self.P_tilde = np.zeros((self.M, self.N))
        self.P_tilde[self.source_id, self.target_id] = 1
        # Per-source-point prior count, and the sum of the prior target points.
        self.P1_tilde = np.sum(self.P_tilde, axis=1)
        self.PX_tilde = np.dot(self.P_tilde, self.X)

    def update_transform(self):
        """
        Calculate a new estimate of the deformable transformation.
        See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.

        Compared to plain deformable CPD, the correspondence priors add terms
        weighted by ``sigma2 / e_alpha``, built from ``P1_tilde`` and
        ``PX_tilde``, to both sides of the linear system solved for ``W``.
        Updates ``self.W``. In the low-rank case it also accumulates the
        regularization energy into ``self.E``.
        """
        # Weight of the correspondence priors relative to the soft matches.
        prior_weight = self.sigma2 / self.e_alpha
        # Diagonal of dP = diag(P1) + prior_weight * diag(P1_tilde), kept as a
        # column vector. "diag(d) @ Z" then becomes the broadcast "d * Z",
        # avoiding a dense (M, M) matrix and an O(M^3) product.
        dP = (self.P1 + prior_weight * self.P1_tilde)[:, np.newaxis]
        # Right-hand side shared by both solvers:
        # PX - diag(P1) Y + prior_weight * (PX_tilde - diag(P1_tilde) Y).
        rhs = self.PX + prior_weight * self.PX_tilde - dP * self.Y

        if self.low_rank is False:
            A = dP * self.G + self.alpha * self.sigma2 * np.eye(self.M)
            self.W = np.linalg.solve(A, rhs)

        elif self.low_rank is True:
            # Matlab code equivalent can be found here:
            # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
            dPQ = dP * self.Q
            inner = np.linalg.solve(
                self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ),
                np.matmul(self.Q.T, rhs))
            self.W = (rhs - np.matmul(dPQ, inner)) / (self.alpha * self.sigma2)
            QtW = np.matmul(self.Q.T, self.W)
            self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))

Classes

class ConstrainedDeformableRegistration (e_alpha=None, source_id=None, target_id=None, *args, **kwargs)

Constrained deformable registration.

Attributes

alpha : float (positive)
Represents the trade-off between the goodness of maximum likelihood fit and regularization.
beta : float (positive)
Width of the Gaussian kernel.
e_alpha : float (positive)
Reliability of the correspondence priors, between 1e-8 (very reliable) and 1 (very unreliable). Defaults to 1e-8.
source_id : numpy.ndarray (int)
Indices of the source points used as correspondence priors; source_id[k] is paired with target_id[k].
target_id : numpy.ndarray (int)
Indices of the matching points in the target array; should have the same length as source_id.
Expand source code
class ConstrainedDeformableRegistration(DeformableRegistration):
    """
    Constrained deformable registration.

    Deformable registration extended with correspondence priors. Every pair
    ``(source_id[k], target_id[k])`` adds a term, weighted by
    ``sigma2 / e_alpha``, that pulls row ``source_id[k]`` of the source point
    set ``Y`` towards row ``target_id[k]`` of the target point set ``X``.

    Attributes
    ----------
    alpha: float (positive)
        Represents the trade-off between the goodness of maximum likelihood fit and regularization.

    beta: float (positive)
        Width of the Gaussian kernel.

    e_alpha: float (positive)
        Reliability of correspondence priors. Between 1e-8 (very reliable) and 1 (very unreliable).
        Defaults to 1e-8 when not given.

    source_id: numpy.ndarray (int)
        Indices for the points to be used as correspondences in the source array.
        ``source_id[k]`` is paired with ``target_id[k]``.

    target_id: numpy.ndarray (int)
        Indices for the points to be used as correspondences in the target array.
        Must have the same length as ``source_id``.

    P_tilde: numpy.ndarray
        (M, N) indicator matrix with a 1 at every prior correspondence.

    P1_tilde: numpy.ndarray
        Row sums of ``P_tilde``: the number of priors attached to each source point.

    PX_tilde: numpy.ndarray
        ``P_tilde @ X``: for each source point, the sum of its prior target points.

    Raises
    ------
    ValueError
        If ``e_alpha`` is not a positive real number, if ``source_id`` or
        ``target_id`` is not a 1D integer numpy array, or if their lengths differ.
    """

    def __init__(self, e_alpha=None, source_id=None, target_id=None, *args, **kwargs):
        # Check the constraint-specific arguments first, so bad input is
        # rejected before the parent constructor does its set-up work.
        # ``not e_alpha > 0`` (rather than ``e_alpha <= 0``) also rejects NaN.
        if e_alpha is not None and (not isinstance(e_alpha, numbers.Real) or not e_alpha > 0):
            raise ValueError(
                "Expected a positive value for regularization parameter e_alpha. Instead got: {}".format(e_alpha))

        for name, ids in (("source", source_id), ("target", target_id)):
            # The fancy indexing below needs integer indices. Float arrays would
            # fail there with an unrelated error, and bool arrays would be
            # interpreted as masks.
            if (not isinstance(ids, np.ndarray) or ids.ndim != 1
                    or not np.issubdtype(ids.dtype, np.integer)):
                raise ValueError(
                    "The {0} ids ({0}_id) must be a 1D numpy array of ints.".format(name))

        if source_id.shape != target_id.shape:
            # Without this check, a length-1 array would silently broadcast
            # against the other one instead of pairing element-wise.
            raise ValueError(
                "source_id and target_id must have the same length. "
                "Instead got {} and {}.".format(source_id.size, target_id.size))

        super().__init__(*args, **kwargs)

        self.e_alpha = 1e-8 if e_alpha is None else e_alpha
        self.source_id = source_id
        self.target_id = target_id
        # Indicator matrix of the prior correspondences, shape (M, N), with a 1
        # at (source_id[k], target_id[k]). Duplicate pairs still give a single 1.
        self.P_tilde = np.zeros((self.M, self.N))
        self.P_tilde[self.source_id, self.target_id] = 1
        # Per-source-point prior count, and the sum of the prior target points.
        self.P1_tilde = np.sum(self.P_tilde, axis=1)
        self.PX_tilde = np.dot(self.P_tilde, self.X)

    def update_transform(self):
        """
        Calculate a new estimate of the deformable transformation.
        See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.

        Compared to plain deformable CPD, the correspondence priors add terms
        weighted by ``sigma2 / e_alpha``, built from ``P1_tilde`` and
        ``PX_tilde``, to both sides of the linear system solved for ``W``.
        Updates ``self.W``. In the low-rank case it also accumulates the
        regularization energy into ``self.E``.
        """
        # Weight of the correspondence priors relative to the soft matches.
        prior_weight = self.sigma2 / self.e_alpha
        # Diagonal of dP = diag(P1) + prior_weight * diag(P1_tilde), kept as a
        # column vector. "diag(d) @ Z" then becomes the broadcast "d * Z",
        # avoiding a dense (M, M) matrix and an O(M^3) product.
        dP = (self.P1 + prior_weight * self.P1_tilde)[:, np.newaxis]
        # Right-hand side shared by both solvers:
        # PX - diag(P1) Y + prior_weight * (PX_tilde - diag(P1_tilde) Y).
        rhs = self.PX + prior_weight * self.PX_tilde - dP * self.Y

        if self.low_rank is False:
            A = dP * self.G + self.alpha * self.sigma2 * np.eye(self.M)
            self.W = np.linalg.solve(A, rhs)

        elif self.low_rank is True:
            # Matlab code equivalent can be found here:
            # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift
            dPQ = dP * self.Q
            inner = np.linalg.solve(
                self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ),
                np.matmul(self.Q.T, rhs))
            self.W = (rhs - np.matmul(dPQ, inner)) / (self.alpha * self.sigma2)
            QtW = np.matmul(self.Q.T, self.W)
            self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))

Ancestors

Inherited members