Source code for pythonsi.domain_adaptation.optimal_transport

import numpy as np
import numpy.typing as npt
from pythonsi.node import Data
from typing import Tuple
from pythonsi.util import solve_quadratic_inequality, intersect
from scipy.cluster.hierarchy import DisjointSet
import ot


def construct_Theta(ns, nt):
    """Build the pairwise source-minus-target difference operator.

    Parameters
    ----------
    ns : int
        Number of source samples.
    nt : int
        Number of target samples.

    Returns
    -------
    ndarray, shape (ns * nt, ns + nt)
        Row ``i * nt + j`` holds ``+1`` in column ``i`` and ``-1`` in column
        ``ns + j``, so multiplying by a stacked ``[source; target]`` column
        yields ``source[i] - target[j]`` for every pair.
    """
    # Each source indicator row is repeated once per target sample.
    source_part = np.repeat(np.identity(ns), nt, axis=0)
    # The negated target indicators are stacked once per source sample.
    target_part = np.tile(-np.identity(nt), (ns, 1))
    return np.concatenate((source_part, target_part), axis=1)


def construct_cost(xs, ys, xt, yt):
    """Compute squared-Euclidean transport costs between source and target.

    Parameters
    ----------
    xs : ndarray, shape (ns, d)
        Source features.
    ys : ndarray, shape (ns, k)
        Source labels (2-D, one row per sample).
    xt : ndarray, shape (nt, d)
        Target features.
    yt : ndarray, shape (nt, k)
        Target labels (2-D, one row per sample).

    Returns
    -------
    feature_cost : ndarray, shape (ns * nt, 1)
        Pairwise squared distances between feature rows only, flattened
        row-major (entry ``i * nt + j`` is the pair ``(i, j)``).
    total_cost : ndarray, shape (ns * nt, 1)
        Feature cost plus the pairwise squared distances between label rows,
        flattened the same way.
    """

    def pairwise_sq_dist(left, right):
        # ||l - r||^2 expanded as ||l||^2 - 2 <l, r> + ||r||^2.
        left_norms = np.sum(left**2, axis=1, keepdims=True)  # (n_left, 1)
        right_norms = np.sum(right**2, axis=1, keepdims=True).T  # (1, n_right)
        return left_norms - 2 * (left @ right.T) + right_norms

    feature_cost = pairwise_sq_dist(xs, xt)
    total_cost = feature_cost + pairwise_sq_dist(ys, yt)
    return feature_cost.reshape(-1, 1), total_cost.reshape(-1, 1)


def construct_H(ns, nt):
    """Build the marginal-constraint matrix for a flattened transport plan.

    Parameters
    ----------
    ns : int
        Number of source samples.
    nt : int
        Number of target samples.

    Returns
    -------
    ndarray, shape (ns + nt - 1, ns * nt)
        The first ``ns`` rows sum the entries of each plan row (row ``i`` has
        ones in columns ``i * nt`` to ``(i + 1) * nt - 1``); the remaining
        rows sum the entries of each plan column. The final column-sum row is
        removed.
    """
    # Row-sum block: every identity column is duplicated nt times side by side.
    row_sums = np.repeat(np.identity(ns), nt, axis=1)

    # Column-sum block: identity blocks placed side by side. One block is
    # always present, followed by ns - 1 further copies.
    col_sums = np.hstack([np.identity(nt)] * max(ns, 1))

    stacked = np.vstack((row_sums, col_sums))
    return stacked[:-1, :]


def construct_h(ns, nt):
    """Build the right-hand side of the marginal constraints.

    Parameters
    ----------
    ns : int
        Number of source samples.
    nt : int
        Number of target samples.

    Returns
    -------
    ndarray, shape (ns + nt - 1, 1)
        ``ns`` entries equal to ``1 / ns`` followed by entries equal to
        ``1 / nt``, with the last entry removed (matching the row dropped in
        ``construct_H``).
    """
    source_mass = np.ones((ns, 1)) / ns
    target_mass = np.ones((nt, 1)) / nt
    stacked = np.concatenate((source_mass, target_mass), axis=0)
    return stacked[:-1, :]


def construct_B(T, u, v, c):
    """Select ``ns + nt - 1`` flat plan indices to serve as a basis.

    Parameters
    ----------
    T : ndarray, shape (ns, nt)
        Transport plan.
    u : ndarray, shape (ns,)
        Per-source values subtracted from each cost row (the caller in this
        file passes the solver's ``log["u"]``).
    v : ndarray, shape (nt,)
        Per-target values subtracted from each cost column (the caller in
        this file passes the solver's ``log["v"]``).
    c : ndarray, shape (ns, nt)
        Cost matrix.

    Returns
    -------
    list
        Sorted flat indices ``i * nt + j``. Every strictly positive entry of
        ``T`` is taken first; if that is fewer than ``ns + nt - 1`` indices,
        more are added in order of increasing ``|c - u - v|``, skipping any
        pair whose source and target nodes are already linked.
    """
    ns, nt = T.shape
    basis_size = ns + nt - 1
    # Nodes 0..ns-1 are sources; nodes ns..ns+nt-1 are targets.
    components = DisjointSet(range(ns + nt))
    basis = []

    # Positive-mass cells, visited in row-major order.
    for flat in np.flatnonzero(T > 0):
        row, col = divmod(flat, nt)
        components.merge(row, col + ns)
        basis.append(flat)

    # Enough positive cells already: keep the first basis_size of them.
    if len(basis) >= basis_size:
        return sorted(basis[:basis_size])

    # Rank every cell by the magnitude of its reduced cost.
    reduced_cost = c - u[:, None] - v[None, :]
    candidates = np.argsort(np.abs(reduced_cost).flatten())

    for flat in candidates:
        if len(basis) >= basis_size:
            break
        row, col = divmod(flat, nt)
        # A cell joining two already-linked nodes would close a cycle.
        if components.connected(row, col + ns):
            continue
        components.merge(row, col + ns)
        basis.append(flat)

    return sorted(basis)


class OptimalTransportDA:
    r"""Optimal Transport Domain Adaptation with selective inference support.

    The optimal transport problem solved is:

    .. math::
        \min_{T \in \mathcal{P}} \langle C, T \rangle

    where :math:`\mathcal{P}` is the set of transport plans with given
    marginals and :math:`C` is the cost matrix between domains.

    Attributes
    ----------
    x_source_node : Data or None
        Source domain feature node
    y_source_node : Data or None
        Source domain label node
    x_target_node : Data or None
        Target domain feature node
    y_target_node : Data or None
        Target domain label node
    x_output_node : Data
        Adapted feature output node
    y_output_node : Data
        Adapted label output node
    interval : list or None
        Feasible interval for the last inference call
    x_output_data : array-like or None
        Stored adapted features from last inference call
    y_output_data : tuple or None
        Stored ``(a_tilde, b_tilde, y_tilde)`` from last inference call
    """

    def __init__(self):
        self.x_source_node = None
        self.y_source_node = None
        self.x_target_node = None
        self.y_target_node = None

        self.x_output_node = Data(self)
        self.y_output_node = Data(self)

        self.interval = None
        self.x_output_data = None
        self.y_output_data = None

    def run(
        self,
        xs: Data,
        ys: Data,
        xt: Data,
        yt: Data,
    ) -> Tuple[Data, Data]:
        r"""Configure domain adaptation with input data.

        Parameters
        ----------
        xs : Data
            Node providing source domain features, shape (ns, d)
        ys : Data
            Node providing source domain labels, shape (ns, 1)
        xt : Data
            Node providing target domain features, shape (nt, d)
        yt : Data
            Node providing target domain labels, shape (nt, 1)

        Returns
        -------
        x_output_node : Data
            Node containing adapted features
        y_output_node : Data
            Node containing adapted labels

        Examples
        --------
        >>> ot_da = OptimalTransportDA()
        >>> x_out, y_out = ot_da.run(xs, ys, xt, yt)
        >>> adapted_x = x_out()
        """
        self.x_source_node = xs
        self.y_source_node = ys
        self.x_target_node = xt
        self.y_target_node = yt
        return self.x_output_node, self.y_output_node

    def forward(
        self,
        xs: npt.NDArray[np.floating],
        ys: npt.NDArray[np.floating],
        xt: npt.NDArray[np.floating],
        yt: npt.NDArray[np.floating],
    ):
        r"""Solve optimal transport and construct adapted dataset.

        Parameters
        ----------
        xs : array-like, shape (ns, d)
            Source domain features
        ys : array-like, shape (ns, 1)
            Source domain labels
        xt : array-like, shape (nt, d)
            Target domain features
        yt : array-like, shape (nt, 1)
            Target domain labels

        Returns
        -------
        x_tilde : array-like, shape (ns+nt, d)
            Adapted feature matrix
        y_tilde : array-like, shape (ns+nt, 1)
            Adapted label vector
        B : list of int
            Basic feasible solution indices
        c_features : array-like, shape (ns*nt, 1)
            Feature space cost matrix
        Omega : array-like, shape (ns+nt, ns+nt)
            Transformation matrix for adaptation

        Notes
        -----
        The adapted dataset is constructed as:

        .. math::
            \tilde{\mathbf{x}} = \Omega
            \begin{bmatrix} \mathbf{x}_s \\ \mathbf{x}_t \end{bmatrix}

            \tilde{\mathbf{y}} = \Omega
            \begin{bmatrix} \mathbf{y}_s \\ \mathbf{y}_t \end{bmatrix}

        where :math:`\Omega` incorporates the optimal transport plan.
        """
        x = np.vstack((xs, xt))
        y = np.vstack((ys, yt))

        ns = xs.shape[0]
        nt = xt.shape[0]

        # Uniform mass on both domains.
        row_mass = np.ones(ns) / ns
        col_mass = np.ones(nt) / nt

        # The plan is solved on the feature+label cost; only the feature
        # part (_c) is handed back to the caller.
        _c, c = construct_cost(xs, ys, xt, yt)
        cost_matrix = c.reshape(ns, nt)

        T, log = ot.emd(a=row_mass, b=col_mass, M=cost_matrix, log=True)

        # The support of the plan is the basis unless it does not have
        # exactly ns + nt - 1 entries, in which case construct_B picks one.
        B = np.where(T.reshape(-1) != 0)[0].tolist()
        if len(B) != ns + nt - 1:
            B = construct_B(T, log["u"], log["v"], cost_matrix)

        # Omega = [[0, ns * T], [0, I]]: the first ns output rows are
        # ns * T applied to the target block, the last nt rows copy the
        # target block unchanged.
        Omega = np.hstack(
            (np.zeros((ns + nt, ns)), np.vstack((ns * T, np.identity(nt))))
        )

        x_tilde = Omega.dot(x)
        y_tilde = Omega.dot(y)
        return x_tilde, y_tilde, B, _c, Omega

    def __call__(self) -> Tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]:
        r"""Execute domain adaptation on stored data.

        Returns
        -------
        x_tilde : array-like, shape (ns+nt, d)
            Adapted feature matrix
        y_tilde : array-like, shape (ns+nt, 1)
            Adapted label vector

        Examples
        --------
        >>> ot_da = OptimalTransportDA()
        >>> # ... set up data nodes ...
        >>> x_adapted, y_adapted = ot_da()
        """
        xs = self.x_source_node()
        ys = self.y_source_node()
        xt = self.x_target_node()
        yt = self.y_target_node()

        x_tilde, y_tilde, _, _, _ = self.forward(xs, ys, xt, yt)

        self.x_output_node.update(x_tilde)
        self.y_output_node.update(y_tilde)
        return x_tilde, y_tilde

    def inference(self, z: float) -> list:
        r"""Find feasible interval of the Optimal Transport for the
        parametrized data at z.

        If ``z`` lies inside the interval cached by the previous call, the
        cached outputs are re-published to the output nodes and the cached
        interval is returned without recomputation.

        Parameters
        ----------
        z : float
            Scalar parameter

        Returns
        -------
        final_interval : list
            Feasible interval [lower, upper] for z
        """
        if self.interval is not None and self.interval[0] <= z <= self.interval[1]:
            self.x_output_node.parametrize(data=self.x_output_data)
            self.y_output_node.parametrize(
                a=self.y_output_data[0],
                b=self.y_output_data[1],
                data=self.y_output_data[2],
            )
            return self.interval

        xs, _, _, interval_xs = self.x_source_node.inference(z)
        ys, a_ys, b_ys, interval_ys = self.y_source_node.inference(z)
        xt, _, _, interval_xt = self.x_target_node.inference(z)
        yt, a_yt, b_yt, interval_yt = self.y_target_node.inference(z)

        # forward already returns Omega applied to the stacked data, so the
        # adapted features/labels are reused instead of being recomputed.
        x_tilde, y_tilde, B, c_, Omega = self.forward(xs, ys, xt, yt)

        a = np.vstack((a_ys, a_yt))
        b = np.vstack((b_ys, b_yt))

        ns = xs.shape[0]
        nt = xt.shape[0]

        # Non-basic indices, sorted so the constraint order is deterministic.
        Bc = sorted(set(range(ns * nt)) - set(B))

        H = construct_H(ns, nt)
        Theta = construct_Theta(ns, nt)
        Theta_a = Theta.dot(a)
        Theta_b = Theta.dot(b)

        # Coefficients of the cost as a quadratic in z,
        # p_tilde + q_tilde * z + r_tilde * z**2 (presumably the labels are
        # parametrized as a + b * z; confirm against the Data node contract).
        p_tilde = c_ + Theta_a * Theta_a
        q_tilde = 2 * Theta_a * Theta_b
        r_tilde = Theta_b * Theta_b

        # Per non-basic index: its own coefficient minus the basic
        # coefficients mapped through H_B^{-1} H_Bc.
        HB_invHBc = np.linalg.inv(H[:, B]).dot(H[:, Bc])
        p = (p_tilde[Bc, :].T - p_tilde[B, :].T.dot(HB_invHBc)).T
        q = (q_tilde[Bc, :].T - q_tilde[B, :].T.dot(HB_invHBc)).T
        r = (r_tilde[Bc, :].T - r_tilde[B, :].T.dot(HB_invHBc)).T

        final_interval = [-np.inf, np.inf]
        for p_i, q_i, r_i in zip(p[:, 0], q[:, 0], r[:, 0]):
            # Coefficients are negated before being handed to the solver,
            # ordered quadratic, linear, constant.
            constraint = solve_quadratic_inequality(-r_i, -q_i, -p_i, z)
            final_interval = intersect(final_interval, constraint)

        # The result is only valid where every upstream node is valid too.
        for upstream in (interval_xs, interval_ys, interval_xt, interval_yt):
            final_interval = intersect(final_interval, upstream)

        a_tilde = Omega.dot(a)
        b_tilde = Omega.dot(b)

        self.x_output_node.parametrize(data=x_tilde)
        self.y_output_node.parametrize(a=a_tilde, b=b_tilde, data=y_tilde)

        # Cache for the fast path at the top of this method.
        self.interval = final_interval
        self.x_output_data = x_tilde
        self.y_output_data = (a_tilde, b_tilde, y_tilde)
        return self.interval