SI for Feature Selection after Optimal Transport-based Domain Adaptation

This example demonstrates how to perform statistical inference for feature selection after Optimal Transport-based Domain Adaptation (OT-based DA). Source-domain data are adapted to the target domain with optimal transport, features are selected with the Lasso on the adapted data, and a p-value is computed for each selected feature. The implementation is based on the paper by Loi et al. (2025) [3].

[3] Loi, N. T., Loc, D. T., & Duy, V. N. L. (2025). “Statistical Inference for Feature Selection after Optimal Transport-based Domain Adaptation.” In International Conference on Artificial Intelligence and Statistics, pp. 1747-1755. PMLR, 2025.
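For orientation, selective inference computes each p-value conditionally on the selection outcome, so that the data-driven choice of features does not invalidate the test. The block below is a sketch of the generic form of a selective p-value used in this line of work. The notation is an assumption made here: Y stacks the source and target responses, Sigma is block-diagonal in the source and target noise covariances, M(Y) is the set of features selected after OT-based DA and the Lasso, and eta_j is a test direction for feature j. The exact choice of eta_j and the characterization of the conditioning event are given in [3].

```latex
Y = \begin{pmatrix} Y_s \\ Y_t \end{pmatrix} \sim \mathcal{N}(\mu, \Sigma), \qquad
\Sigma = \begin{pmatrix} \Sigma_s & 0 \\ 0 & \Sigma_t \end{pmatrix}, \qquad
H_{0,j}: \eta_j^\top \mu = 0,

p_j = \mathbb{P}_{H_{0,j}}\!\left( \left|\eta_j^\top Y\right| \ge \left|\eta_j^\top Y^{\mathrm{obs}}\right|
\;\middle|\; \mathcal{M}(Y) = \mathcal{M}(Y^{\mathrm{obs}}),\;
\mathcal{Q}(Y) = \mathcal{Q}(Y^{\mathrm{obs}}) \right),

\mathcal{Q}(Y) = \left(I - b\,\eta_j^\top\right) Y, \qquad
b = \frac{\Sigma \eta_j}{\eta_j^\top \Sigma \eta_j}.
```

Conditioning on the nuisance component Q(Y) reduces the problem to one dimension, so under H_{0,j} the statistic eta_j^T Y follows a truncated normal distribution whose truncation region is determined by the selection event.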

# Author: Tran Tuan Kiet

from pythonsi import Pipeline
from pythonsi.feature_selection import LassoFeatureSelection
from pythonsi import Data
from pythonsi.test_statistics import SFS_DATestStatistic
from pythonsi.domain_adaptation import OptimalTransportDA
import numpy as np
import matplotlib.pyplot as plt

Define the pipeline

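def SFS_DA() -> Pipeline:
    # Placeholder nodes for the source-domain features and response
    xs = Data()
    ys = Data()

    # Placeholder nodes for the target-domain features and response
    xt = Data()
    yt = Data()

    # Optimal Transport-based domain adaptation: produces the adapted features and response
    OT = OptimalTransportDA()
    x_tilde, y_tilde = OT.run(xs=xs, ys=ys, xt=xt, yt=yt)

    # Lasso feature selection on the adapted data; the output is the active set
    lasso = LassoFeatureSelection(lambda_=10)
    active_set = lasso.run(x_tilde, y_tilde)

    # Connect the inputs, the selected features, and the test statistic used for the p-values
    return Pipeline(
        inputs=(xs, ys, xt, yt),
        output=active_set,
        test_statistic=SFS_DATestStatistic(xs=xs, ys=ys, xt=xt, yt=yt),
    )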


my_pipeline = SFS_DA()
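The pipeline above hides what the adaptation and selection steps compute. The block below is a hypothetical plain NumPy/SciPy sketch of those two steps, not the internals of `OptimalTransportDA` or `LassoFeatureSelection`. It assumes uniform sample weights, a squared Euclidean cost on the stacked (x, y) rows, a barycentric mapping of the source rows onto the target, that the adapted data are the transported source rows stacked with the target rows, and the Lasso objective 0.5 * ||y - X b||^2 + lambda * ||b||_1 (whether pythonsi scales `lambda_` this way is also an assumption). The helper names `ot_plan`, `ot_adapt`, and `lasso_active_set` are made up for illustration. The sketch only returns the active set; it does not compute selective p-values.

```python
import numpy as np
from scipy.optimize import linprog


def ot_plan(src, tgt):
    """Hypothetical helper: exact OT plan between uniform empirical measures on the rows."""
    n, m = len(src), len(tgt)
    # Squared Euclidean cost between every source row and every target row
    cost = ((src[:, None, :] - tgt[None, :, :]) ** 2).sum(axis=2)
    # Marginal constraints: row i of the plan sums to 1/n, column j sums to 1/m
    a_eq = np.zeros((n + m, n * m))
    for i in range(n):
        a_eq[i, i * m:(i + 1) * m] = 1.0
    for j in range(m):
        a_eq[n + j, j::m] = 1.0
    b_eq = np.concatenate([np.full(n, 1.0 / n), np.full(m, 1.0 / m)])
    res = linprog(cost.ravel(), A_eq=a_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return res.x.reshape(n, m)


def ot_adapt(xs, ys, xt, yt):
    """Hypothetical helper: transport source rows onto the target, then stack with target."""
    src = np.hstack([xs, ys])
    tgt = np.hstack([xt, yt])
    plan = ot_plan(src, tgt)
    # Barycentric mapping: each source row becomes a plan-weighted average of target rows
    mapped = len(xs) * plan @ tgt
    x_tilde = np.vstack([mapped[:, :-1], xt])
    y_tilde = np.vstack([mapped[:, -1:], yt])
    return x_tilde, y_tilde


def lasso_active_set(x, y, lam, n_iter=500):
    """Hypothetical helper: minimize 0.5*||y - x b||^2 + lam*||b||_1 by coordinate descent."""
    y = y.ravel()
    p = x.shape[1]
    b = np.zeros(p)
    col_sq = (x ** 2).sum(axis=0)
    r = y - x @ b
    for _ in range(n_iter):
        for j in range(p):
            r += x[:, j] * b[j]  # partial residual without feature j
            rho = x[:, j] @ r
            b[j] = np.sign(rho) * max(abs(rho) - lam, 0.0) / col_sq[j]  # soft-thresholding
            r -= x[:, j] * b[j]
    return np.flatnonzero(b)  # indices of nonzero coefficients (the active set)


rng = np.random.default_rng(0)
xs_demo = rng.normal(size=(40, 5))
ys_demo = 3.0 * xs_demo[:, [0]] + rng.normal(size=(40, 1))
xt_demo = rng.normal(size=(10, 5))
yt_demo = 3.0 * xt_demo[:, [0]] + rng.normal(size=(10, 1))

x_tilde, y_tilde = ot_adapt(xs_demo, ys_demo, xt_demo, yt_demo)
print(x_tilde.shape, y_tilde.shape)  # (50, 5) (50, 1)
print(lasso_active_set(x_tilde, y_tilde, lam=10.0))
```

The OT problem is solved exactly as a linear program here for transparency; for larger samples a dedicated OT solver would be the more practical choice.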

Generate data

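def gen_data(n, p, true_beta):
    # Design matrix with i.i.d. standard normal entries
    x = np.random.normal(loc=0, scale=1, size=(n, p))
    true_beta = true_beta.reshape(-1, 1)

    # Linear mean plus i.i.d. N(0, 1) noise; Sigma is the matching noise covariance
    mu = x.dot(true_beta)
    Sigma = np.identity(n)
    Y = mu + np.random.normal(loc=0, scale=1, size=(n, 1))
    return x, Y, Sigma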


# 150 source samples and 25 target samples with 5 features; all true coefficients are zero
xs, ys, sigma_s = gen_data(150, 5, np.asarray([0, 0, 0, 0, 0]))
xt, yt, sigma_t = gen_data(25, 5, np.asarray([0, 0, 0, 0, 0]))
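The covariance returned by `gen_data` is presumably what the pipeline receives through `covariances=`, so it should describe the noise that was actually used to generate the response. In `gen_data` the noise is N(0, I), which matches `Sigma = np.identity(n)`. The sketch below is a hypothetical variant, `gen_data_cov`, that draws the noise from N(0, Sigma) for a user-supplied `Sigma` and returns that same matrix, so the two cannot drift apart.

```python
import numpy as np


def gen_data_cov(n, p, true_beta, Sigma, rng=None):
    """Hypothetical variant of gen_data: noise drawn from N(0, Sigma) instead of N(0, I)."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.normal(loc=0.0, scale=1.0, size=(n, p))
    mu = x @ np.asarray(true_beta, dtype=float).reshape(-1, 1)
    # Correlated (or heteroscedastic) noise with the covariance that is returned
    eps = rng.multivariate_normal(mean=np.zeros(n), cov=Sigma).reshape(-1, 1)
    return x, mu + eps, Sigma


rng = np.random.default_rng(0)
# Example: independent noise with variance 0.5 in the source domain
xs_demo, ys_demo, sigma_s_demo = gen_data_cov(150, 5, np.zeros(5), 0.5 * np.identity(150), rng)
print(xs_demo.shape, ys_demo.shape, sigma_s_demo.shape)  # (150, 5) (150, 1) (150, 150)
```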

Run the pipeline

selected_features, p_values = my_pipeline(
    inputs=[xs, ys, xt, yt], covariances=[sigma_s, sigma_t]
)

print("Selected features: ", selected_features)
print("P-values: ", p_values)
Selected features:  [0 1 2 3]
P-values:  [0.5424014971318711, 0.21260386674254184, 0.15564751961626722, 0.41546429041302224]
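Because `true_beta` is zero in both domains, every selected feature is a null feature, so a valid selective test should reject each one only with probability at most alpha. The sketch below turns the printed p-values into decisions at an assumed significance level of 0.05; the values are copied from the output above, and the threshold is a choice made here, not part of the example.

```python
import numpy as np

# Values copied from the printed output above
selected_features = np.array([0, 1, 2, 3])
p_values = [0.5424014971318711, 0.21260386674254184, 0.15564751961626722, 0.41546429041302224]

alpha = 0.05  # assumed significance level
significant = [int(j) for j, p in zip(selected_features, p_values) if p <= alpha]
print("Significant at alpha =", alpha, ":", significant)  # Significant at alpha = 0.05 : []
```

No feature is declared significant, which is consistent with the null setting used to generate the data.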

Plot the p-values

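plt.figure()
# Label each bar with the index of the selected feature it refers to,
# not its position in the list of p-values
plt.bar([str(j) for j in selected_features], p_values)
plt.xlabel("Selected feature index")
plt.ylabel("P-value")
plt.show()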
[Figure: SFS DA, bar chart of the p-values of the selected features]
