Source code for forecopy.fun

import pandas as pd
import jax.numpy as jnp
import scipy.sparse as sps
import jax as jax
from itertools import chain
import lineax as lx
import jax.scipy.linalg as jsl

def factors(n):
    kset = list(chain.from_iterable([[i, n//i] for i in range(1, int(n**0.5) + 1) if n % i == 0]))
    kset = list(set(kset))
    kset.sort(reverse=True)
    return kset

def vec2hmat(vec, h, kset):
    if len(vec.shape) == 2:
        vec = vec[0,:]

    m = max(kset)
    freq = m/jnp.array(kset)
    idx = jnp.tile(jnp.array([range(1, h+1)]), len(kset))[0, :]
    idx = jnp.repeat(idx, jnp.repeat(freq, h).astype(int))
    idx_sort = jnp.argsort(idx, stable=True)

    return vec[idx_sort].reshape(h, int(len(idx)/h))

def hmat2vec(hmat, kset):
    h = hmat.shape[0]
    m = max(kset)
    freq = m/jnp.array(kset)
    it = jnp.tile(jnp.repeat(freq, freq.astype(int)), h)

    return (hmat.T.flatten('F'))[jnp.argsort(it, stable=True)]

def FoReco2pd(reco: jnp.ndarray, type: str = "cs", labels: tuple = None):
    if labels is None:
        labels = [f"s-{i+1}" for i in range(reco.shape[1])]

    index = [f"h-{i+1}" for i in range(reco.shape[0])]

    df = pd.DataFrame(data=reco, index=index, columns=labels)

    return df

def isDiag(M):
    if len(M.shape) == 1:
        return True
    elif sps.issparse(M):
        i, j, _= sps.find(M)
        return jnp.all(i == j)
    else:
        i, j = M.shape
        assert i == j 
        test = M.reshape(-1)[:-1].reshape(i-1, j+1)
        return ~jnp.any(test[:, 1:])

def lin_sys(lhs, rhs, solver = 'default'):
    if solver == 'lineax':
        lhs = lx.MatrixLinearOperator(lhs, lx.positive_semidefinite_tag)
        solver = lx.Cholesky()
        state = solver.init(lhs, options={})
        def fun_solver(x):
            sol = solver.compute(state, x, options={})
            return sol[0]
        vmap_solver = jax.vmap(fun_solver, 1)
        return vmap_solver(rhs).T
    elif solver == 'cholesky':
        cho_factor = jsl.cho_factor(lhs, lower=True)
        return jsl.cho_solve(cho_factor, rhs)
    else:
        return jsl.solve(lhs, rhs)

def _mcrossprod(x):
    return jnp.dot(x.T,x)

def _covcor(cov):
    d = jnp.sqrt(jnp.diag(cov))
    corm = ((cov / d).T)/d
    corm = jnp.fill_diagonal(corm, 1.0, inplace = False)
    return corm

def sample_estim(x, mse=True):
    if jnp.any(jnp.isnan(x)):
        n = (~jnp.isnan(x)).sum(0)
        n = n * jnp.ones(len(n))[:, None]
        if not mse:
            x -= jnp.nanmean(x, 0)
        x[jnp.isnan(x)] = 0
        return (x.T @ x) * (1/jnp.minimum(n,n.T))
    else:
        if not mse:
            return jnp.cov(x.T)
        else:
            return jnp.dot(x.T, x) / x.shape[0]

[docs] def shrink_estim(x, mse=True): """ Shrinkage of the covariance matrix ---------------------------------- Shrinkage of the covariance matrix according to Schafer and Strimmer (2005). Parameters ---------- ``x``: ndarray A numeric matrix containing the residuals. ``mse``: bool, default True When `True`, residual moments are not mean-corrected (i.e., MSE rather than unbiased variance). When `False`, apply mean correction. Returns ------- A shrunk covariance matrix. References ---------- * Schafer, J.L. and Strimmer, K. (2005), A shrinkage approach to large-scale covariance matrix estimation and implications for functional genomics, `Statistical Applications in Genetics and Molecular Biology`, 4, 1 See Also -------- :func:`tecov <forecopy.cov.tecov>` :func:`cscov <forecopy.cov.cscov>` """ n = x.shape[0] covm = sample_estim(x, mse = mse) diag_covm = jnp.diag(covm) if n>3: xs = x/jnp.sqrt(diag_covm) vS = (1 / (n * (n - 1))) * (_mcrossprod(xs**2) - ((1 / n) * (_mcrossprod(xs)**2))) vS = jnp.fill_diagonal(vS, 0.0, inplace = False) corm = _covcor(covm) corm = jnp.fill_diagonal(corm, 0.0, inplace = False) corm = corm**2 lam = (jnp.sum(vS)/jnp.sum(corm)).item(0) lam = max(min(lam, 1), 0) else: lam = 1 shrink_cov = lam * jnp.diag(diag_covm) + (1 - lam) * covm return {'cov': shrink_cov, 'lambda': lam}
[docs] def is_PD(matrix: jnp.ndarray, tol: float = 1e-10) -> bool: """ Check if a matrix is positive definite. Parameters ---------- matrix : jnp.ndarray The matrix to validate. tol : float, default 1e-10 Tolerance for symmetry check. Returns ------- bool True if the matrix is a valid covariance matrix. """ # Check if matrix is square if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]: return False # Check if matrix is symmetric if jnp.max(jnp.abs(matrix - matrix.T)) > tol: return False # Check if matrix is positive definite using Cholesky decomposition try: jnp.linalg.cholesky(matrix) return True except jax.linalg.LinAlgError: return False