
Module hetmatpy.degree_weight


View Source
import collections

import copy

import functools

import itertools

import logging

import numpy

from hetnetpy.matrix import sparsify_or_densify

from scipy import sparse

import hetmatpy.hetmat

import hetmatpy.matrix

from hetmatpy.hetmat.caching import path_count_cache

def _category_to_function(category, dwwc_method):

    function_dictionary = {

        "no_repeats": dwwc_method,

        "disjoint": _dwpc_disjoint,

        "disjoint_groups": _dwpc_disjoint,

        "short_repeat": _dwpc_short_repeat,

        "four_repeat": _dwpc_baba,

        "long_repeat": _dwpc_general_case,

        "BAAB": _dwpc_baab,

        "BABA": _dwpc_baba,

        "repeat_around": _dwpc_repeat_around,

        "interior_complete_group": _dwpc_baba,

        "other": _dwpc_general_case,

    }

    return function_dictionary[category]

@path_count_cache(metric="dwpc")

def dwpc(

    graph,

    metapath,

    damping=0.5,

    dense_threshold=0,

    approx_ok=False,

    dtype=numpy.float64,

    dwwc_method=None,

):

    """

    A unified function to compute the degree-weighted path count.

    This function will call get_segments, then the appropriate

    specialized (or generalized) DWPC function.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    damping : float

    dense_threshold : float (0 <= dense_threshold <= 1)

        sets the density threshold above which a sparse matrix will be

        converted to a dense automatically.

    approx_ok : bool

        if True, uses an approximation to DWPC. If False, dwpc will call

        _dwpc_general_case and give a warning on metapaths which are

        categorized 'other' or 'long_repeat'.

    dtype : dtype object

        numpy.float32 or numpy.float64. At present, numpy.float16 fails when

        using sparse matrices, due to a bug in scipy.sparse

    dwwc_method : function

        dwwc method to use for computing DWWCs. If set to None, use

        module-level default (default_dwwc_method).

    Returns

    -------

    numpy.ndarray

        row labels

    numpy.ndarray

        column labels

    numpy.ndarray or scipy.sparse.csc_matrix

        the DWPC matrix

    """

    category = categorize(metapath)

    dwpc_function = _category_to_function(category, dwwc_method=dwwc_method)

    if category in ("long_repeat", "other"):

        if approx_ok:

            dwpc_function = _dwpc_approx

        else:

            logging.warning(

                f"Metapath {metapath} will use _dwpc_general_case, "

                "which can require very long computations."

            )

    row_names, col_names, dwpc_matrix = dwpc_function(

        graph, metapath, damping, dense_threshold=dense_threshold, dtype=dtype

    )

    return row_names, col_names, dwpc_matrix

@path_count_cache(metric="dwwc")

def dwwc(

    graph,

    metapath,

    damping=0.5,

    dense_threshold=0,

    dtype=numpy.float64,

    dwwc_method=None,

):

    """

    Compute the degree-weighted walk count (DWWC) in which nodes can be

    repeated within a path.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    damping : float

    dense_threshold : float (0 <= dense_threshold <= 1)

        sets the density threshold at which a sparse matrix will be

        converted to a dense automatically.

    dtype : dtype object

    dwwc_method : function

        dwwc method to use for computing DWWCs. If set to None, use

        module-level default (default_dwwc_method).

    """

    return dwwc_method(

        graph=graph,

        metapath=metapath,

        damping=damping,

        dense_threshold=dense_threshold,

        dtype=dtype,

    )

def dwwc_sequential(

    graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64

):

    """

    Compute the degree-weighted walk count (DWWC) in which nodes can be

    repeated within a path.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    damping : float

    dense_threshold : float (0 <= dense_threshold <= 1)

        sets the density threshold at which a sparse matrix will be

        converted to a dense automatically.

    dtype : dtype object

    """

    dwwc_matrix = None

    row_names = None

    for metaedge in metapath:

        rows, cols, adj_mat = hetmatpy.matrix.metaedge_to_adjacency_matrix(

            graph, metaedge, dense_threshold=dense_threshold, dtype=dtype

        )

        adj_mat = _degree_weight(adj_mat, damping, dtype=dtype)

        if dwwc_matrix is None:

            row_names = rows

            dwwc_matrix = adj_mat

        else:

            dwwc_matrix = dwwc_matrix @ adj_mat

            dwwc_matrix = sparsify_or_densify(dwwc_matrix, dense_threshold)

    return row_names, cols, dwwc_matrix

def dwwc_recursive(

    graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64

):

    """

    Recursive DWWC implementation to take better advantage of caching.

    """

    rows, cols, adj_mat = hetmatpy.matrix.metaedge_to_adjacency_matrix(

        graph, metapath[0], dense_threshold=dense_threshold, dtype=dtype

    )

    adj_mat = _degree_weight(adj_mat, damping, dtype=dtype)

    if len(metapath) > 1:

        _, cols, dwwc_next = dwwc(

            graph,

            metapath[1:],

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

            dwwc_method=dwwc_recursive,

        )

        dwwc_matrix = adj_mat @ dwwc_next

    else:

        dwwc_matrix = adj_mat

    dwwc_matrix = sparsify_or_densify(dwwc_matrix, dense_threshold)

    return rows, cols, dwwc_matrix

def _multi_dot(metapath, order, i, j, graph, damping, dense_threshold, dtype):

    """

    Perform matrix multiplication with the given order. Modified from

    numpy.linalg.linalg._multi_dot (https://git.io/vh31f) which is released

    under a 3-Clause BSD License (https://git.io/vhCDC).

    """

    if i == j:

        _, _, adj_mat = hetmatpy.matrix.metaedge_to_adjacency_matrix(

            graph, metapath[i], dense_threshold=dense_threshold, dtype=dtype

        )

        adj_mat = _degree_weight(adj_mat, damping=damping, dtype=dtype)

        return adj_mat

    return _multi_dot(

        metapath, order, i, order[i, j], graph, damping, dense_threshold, dtype

    ) @ _multi_dot(

        metapath, order, order[i, j] + 1, j, graph, damping, dense_threshold, dtype

    )

def _dimensions_to_ordering(dimensions):

    # Find optimal matrix chain ordering. See https://git.io/vh38o

    n = len(dimensions) - 1

    m = numpy.zeros((n, n), dtype=numpy.double)

    ordering = numpy.empty((n, n), dtype=numpy.intp)

    for l_ in range(1, n):

        for i in range(n - l_):

            j = i + l_

            m[i, j] = numpy.inf

            for k in range(i, j):

                q = (

                    m[i, k]

                    + m[k + 1, j]

                    + dimensions[i] * dimensions[k + 1] * dimensions[j + 1]

                )

                if q < m[i, j]:

                    m[i, j] = q

                    ordering[i, j] = k

    return ordering

def dwwc_chain(graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64):

    """

    Uses optimal matrix chain multiplication as in numpy.multi_dot, but allows

    for sparse matrices. Uses ordering modified from numpy.linalg.linalg._multi_dot

    (https://git.io/vh31f) which is released under a 3-Clause BSD License

    (https://git.io/vhCDC).

    """

    metapath = graph.metagraph.get_metapath(metapath)

    array_dims = [graph.count_nodes(mn) for mn in metapath.get_nodes()]

    row_ids = hetmatpy.matrix.get_node_identifiers(graph, metapath.source())

    columns_ids = hetmatpy.matrix.get_node_identifiers(graph, metapath.target())

    ordering = _dimensions_to_ordering(array_dims)

    dwwc_matrix = _multi_dot(

        metapath, ordering, 0, len(metapath) - 1, graph, damping, dense_threshold, dtype

    )

    dwwc_matrix = sparsify_or_densify(dwwc_matrix, dense_threshold)

    return row_ids, columns_ids, dwwc_matrix

def categorize(metapath):

    """

    Returns the classification of a given metapath as one of

    a set of metapath types which we approach differently.

    Parameters

    ----------

    metapath : hetnetpy.hetnet.MetaPath

    Returns

    -------

    classification : string

        One of ['no_repeats', 'disjoint', 'short_repeat',

                'long_repeat', 'BAAB', 'BABA', 'repeat_around',

                 'interior_complete_group', 'disjoint_groups', 'other']

    Examples

    --------

    GbCtDlA -> 'no_repeats'

    GiGiG   -> 'short_repeat'

    GiGiGcG -> 'four_repeat'

    GiGcGiGiG -> 'long_repeat'

    GiGbCrC -> 'disjoint'

    GbCbGbC -> 'BABA'

    GbCrCbG -> 'BAAB'

    DaGiGbCrC -> 'disjoint'

    GiGaDpCrC -> 'disjoint'

    GiGbCrCpDrD -> 'disjoint'

    GbCpDaGbCpD -> 'other'

    GbCrCrCrCrCbG -> 'other'

    """

    metanodes = list(metapath.get_nodes())

    freq = collections.Counter(metanodes)

    repeated = {metanode for metanode, count in freq.items() if count > 1}

    if not repeated:

        return "no_repeats"

    repeats_only = [node for node in metanodes if node in repeated]

    # Group neighbors if they are the same

    grouped = [list(v) for k, v in itertools.groupby(repeats_only)]

    # Handle multiple disjoint repeats, any number, ie. AA,BB,CC,DD,...

    if len(grouped) == len(repeated):

        # Identify if there is only one metanode

        if len(repeated) == 1:

            if max(freq.values()) < 4:

                return "short_repeat"

            elif max(freq.values()) == 4:

                return "four_repeat"

            else:

                return "long_repeat"

        return "disjoint"

    assert len(repeats_only) > 3

    # Categorize the reformatted metapath

    if len(repeats_only) == 4:

        if repeats_only[0] == repeats_only[-1]:

            assert repeats_only[1] == repeats_only[2]

            return "BAAB"

        else:

            assert (

                repeats_only[0] == repeats_only[2]

                and repeats_only[1] == repeats_only[3]

            )

            return "BABA"

    elif len(repeats_only) == 5 and max(map(len, grouped)) == 3:

        if repeats_only[0] == repeats_only[-1]:

            return "BAAB"

    elif repeats_only == list(reversed(repeats_only)) and not len(repeats_only) % 2:

        return "BAAB"

    # 6 node paths with 3x2 repeats

    elif len(repeated) == 3 and len(metapath) == 5:

        if repeats_only[0] == repeats_only[-1]:

            return "repeat_around"

        # AABCCB or AABCBC

        elif len(grouped[0]) == 2 or len(grouped[-1]) == 2:

            return "disjoint_groups"

        # ABA CC B

        elif len(repeats_only) - len(grouped) == 1:

            return "interior_complete_group"

        # most complicated len 6

        else:

            return "other"

    else:

        # Multi-repeats that aren't disjoint, eg. ABCBAC

        if len(repeated) > 2:

            logging.info(

                f"{metapath}: Only two overlapping repeats currently supported"

            )

            return "other"

        if len(metanodes) > 4:

            logging.info(

                f"{metapath}: Complex metapaths of length > 4 are not yet " f"supported"

            )

            return "other"

        assert False

def get_segments(metagraph, metapath):

    """

    Split a metapath into segments of recognized groups and non-repeated

    nodes. Groups include BAAB, BABA, disjoint short- and long-repeats.

    Returns an error for categorization 'other'.

    Parameters

    ----------

    metagraph : hetnetpy.hetnet.MetaGraph

    metapath : hetnetpy.hetnet.MetaPath

    Returns

    -------

    list

        list of metapaths. If the metapath is not segmentable or is already

        fully simplified (eg. GiGaDaG), then the list will have only one

        element.

    Examples

    --------

    'CbGaDaGaD' -> ['CbG', 'GaD', 'GaG', 'GaD']

    'GbCpDaGaD' -> ['GbCpD', 'DaG', 'GaD']

    'CrCbGiGaDrD' -> ['CrC', 'CbG', 'GiG', 'GaD', 'DrD']

    """

    def add_head_tail(metapath, indices):

        """Makes sure that all metanodes are included in segments.

        Ensures that the first segment goes all the way back to the

        first metanode. Similarly, makes sure that the last segment

        includes all metanodes up to the last one."""

        # handle non-duplicated on the front

        if indices[0][0] != 0:

            indices = [(0, indices[0][0])] + indices

        # handle non-duplicated on the end

        if indices[-1][-1] != len(metapath):

            indices = indices + [(indices[-1][-1], len(metapath))]

        return indices

    metapath = metagraph.get_metapath(metapath)

    category = categorize(metapath)

    metanodes = metapath.get_nodes()

    freq = collections.Counter(metanodes)

    repeated = {i for i in freq.keys() if freq[i] > 1}

    if category == "no_repeats":

        return [metapath]

    elif category == "repeat_around":

        # Note this is hard-coded and will need to be updated for various

        # metapath lengths

        indices = [[0, 1], [1, 4], [4, 5]]

    elif category == "disjoint_groups":

        # CCBABA or CCBAAB or BABACC or BAABCC -> [CC, BABA], etc.

        metanodes = list(metapath.get_nodes())

        grouped = [list(v) for k, v in itertools.groupby(metanodes)]

        indices = (

            [[0, 1], [1, 2], [2, 5]]

            if len(grouped[0]) == 2

            else [[0, 3], [3, 4], [4, 5]]

        )

    elif category in ("disjoint", "short_repeat", "long_repeat"):

        indices = sorted(

            [metanodes.index(i), len(metapath) - list(reversed(metanodes)).index(i)]

            for i in repeated

        )

        indices = add_head_tail(metapath, indices)

        # handle middle cases with non-repeated nodes between disjoint regions

        # Eg. [[0,2], [3,4]] -> [[0,2],[2,3],[3,4]]

        inds = []

        for i, v in enumerate(indices[:-1]):

            inds.append(v)

            if v[-1] != indices[i + 1][0]:

                inds.append([v[-1], indices[i + 1][0]])

        indices = inds + [indices[-1]]

    elif category == "four_repeat":

        nodes = set(metanodes)

        repeat_indices = [

            [i for i, v in enumerate(metanodes) if v == metanode] for metanode in nodes

        ]

        repeat_indices = [i for i in repeat_indices if len(i) > 1]

        simple_repeats = [i for group in repeat_indices for i in group]

        seconds = simple_repeats[1:] + [simple_repeats[-1]]

        indices = list(zip(simple_repeats, seconds))

        indices = add_head_tail(metapath, indices)

    elif category in ("BAAB", "BABA", "other", "interior_complete_group"):

        nodes = set(metanodes)

        repeat_indices = [

            [i for i, v in enumerate(metanodes) if v == metanode] for metanode in nodes

        ]

        repeat_indices = [i for i in repeat_indices if len(i) > 1]

        simple_repeats = [i for group in repeat_indices for i in group]

        inds = []

        for i in repeat_indices:

            if len(i) == 2:

                inds += i

            if len(i) > 2:

                inds.append(i[0])

                inds.append(i[-1])

                for j in i[1:-1]:

                    if (j - 1 in simple_repeats and j + 1 in simple_repeats) and not (

                        j - 1 in i and j + 1 in i

                    ):

                        inds.append(j)

        inds = sorted(inds)

        seconds = inds[1:] + [inds[-1]]

        indices = list(zip(inds, seconds))

        indices = [i for i in indices if len(set(i)) == 2]

        indices = add_head_tail(metapath, indices)

    segments = [metapath[i[0] : i[1]] for i in indices]

    segments = [i for i in segments if i]

    segments = [metagraph.get_metapath(metaedges) for metaedges in segments]

    # eg: B CC ABA

    if category == "interior_complete_group":

        segs = []

        for i, v in enumerate(segments[:-1]):

            if segments[i + 1].source() == segments[i + 1].target():

                edges = v.edges + segments[i + 1].edges + segments[i + 2].edges

                segs.append(metagraph.get_metapath(edges))

            elif v.source() == v.target():

                pass

            elif segments[i - 1].source() == segments[i - 1].target():

                pass

            else:

                segs.append(v)

        segs.append(segments[-1])

        segments = segs

    return segments

def get_all_segments(metagraph, metapath):

    """

    Return all subsegments of a given metapath, including those segments that

    appear only after early splits.

    Parameters

    ----------

    metagraph : hetnetpy.hetnet.MetaGraph

    metapath : hetnetpy.hetnet.MetaPath

    Returns

    -------

    list

    Example

    -------

    >>> get_all_segments(metagraph, CrCbGaDrDaG)

    [CrC, CbG, GaDrDaG, GaD, DrD, DaG]

    """

    metapath = metagraph.get_metapath(metapath)

    segments = get_segments(metagraph, metapath)

    if len(segments) == 1:

        return [metapath]

    all_subsegments = [metapath]

    for segment in segments:

        subsegments = get_all_segments(metagraph, segment)

        next_split = subsegments if len(subsegments) > 1 else []

        all_subsegments = all_subsegments + [segment] + next_split

    return all_subsegments

def order_segments(metagraph, metapaths, store_inverses=False):

    """

    Gives the frequencies of metapath segments that occur when computing DWPC.

    In DWPC computation, metapaths are split a number of times for simpler computation.

    This function finds how frequently each segment would be used when computing

    DWPC for all given metapaths, which enables targeted caching of the most

    frequently used segments.

    Parameters

    ----------

    metagraph : hetnetpy.hetnet.MetaGraph

    metapaths : list

        list of hetnetpy.hetnet.MetaPath objects

    store_inverses : bool

        Whether or not to include both forward and backward directions of segments.

        For example, if False: [CbG, GbC] -> [CbG, CbG], else no change.

    Returns

    -------

    collections.Counter

        Number of times each metapath segment appears when getting all segments.

    """

    all_segments = [

        segment

        for metapath in metapaths

        for segment in get_all_segments(metagraph, metapath)

    ]

    if not store_inverses:

        # Change all instances of inverted segments to the same direction, using a first-seen ordering

        seen = set()

        aligned_segments = list()

        for segment in all_segments:

            add = segment.inverse if segment.inverse in seen else segment

            aligned_segments.append(add)

            seen.add(add)

        all_segments = aligned_segments

    segment_counts = collections.Counter(all_segments)

    return segment_counts

def remove_diag(mat, dtype=numpy.float64):

    """Set the main diagonal of a square matrix to zeros."""

    assert mat.shape[0] == mat.shape[1]  # must be square

    if sparse.issparse(mat):

        return mat - sparse.diags(mat.diagonal(), dtype=dtype)

    else:

        return mat - numpy.diag(mat.diagonal())

def _degree_weight(matrix, damping, copy=True, dtype=numpy.float64):

    """Normalize an adjacency matrix by the in and out degree."""

    matrix = hetmatpy.matrix.copy_array(matrix, copy, dtype=dtype)

    row_sums = numpy.array(matrix.sum(axis=1), dtype=dtype).flatten()

    column_sums = numpy.array(matrix.sum(axis=0), dtype=dtype).flatten()

    matrix = hetmatpy.matrix.normalize(matrix, row_sums, "rows", damping)

    matrix = hetmatpy.matrix.normalize(matrix, column_sums, "columns", damping)

    return matrix

def _dwpc_approx(graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64):

    """

    Compute an approximation of DWPC. Only removes the diagonal for the first

    repeated node, and any disjoint repetitions that follow the last occurrence

    of the first repeating node.

    Examples

    --------

    GiGbCrC -> Identical output to DWPC

    GiGbCbGiG -> Approximation

    """

    dwpc_matrix = None

    row_names = None

    # Find the first repeated metanode and where it occurs

    nodes = metapath.get_nodes()

    repeated_nodes = [node for i, node in enumerate(nodes) if node in nodes[i + 1 :]]

    first_repeat = repeated_nodes[0]

    repeated_indices = [i for i, v in enumerate(nodes) if v == first_repeat]

    for i, segment in enumerate(repeated_indices[1:]):

        rows, cols, dwpc_matrix = dwpc(

            graph,

            metapath[repeated_indices[i] : segment],

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        if row_names is None:

            row_names = rows

    # Add head and tail segments, if applicable

    if repeated_indices[0] != 0:

        row_names, _, head_seg = dwwc(

            graph,

            metapath[0 : repeated_indices[0]],

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = head_seg @ dwpc_matrix

    if nodes[repeated_indices[-1]] != nodes[-1]:

        _, cols, tail_seg = dwpc(

            graph,

            metapath[repeated_indices[-1] :],

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = dwpc_matrix @ tail_seg

    dwpc_matrix = sparsify_or_densify(dwpc_matrix, dense_threshold)

    return row_names, cols, dwpc_matrix

def _dwpc_disjoint(

    graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64

):

    """DWPC for disjoint repeats or disjoint groups"""

    segments = get_segments(graph.metagraph, metapath)

    row_names = None

    col_names = None

    dwpc_matrix = None

    for segment in segments:

        rows, cols, seg_matrix = dwpc(

            graph,

            segment,

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        if row_names is None:

            row_names = rows

        if segment is segments[-1]:

            col_names = cols

        if dwpc_matrix is None:

            dwpc_matrix = seg_matrix

        else:

            dwpc_matrix = dwpc_matrix @ seg_matrix

    return row_names, col_names, dwpc_matrix

def _dwpc_repeat_around(

    graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64

):

    """

    DWPC for situations in which we have a surrounding repeat like

    B----B, where the middle group is a more complicated group. The

    purpose of this function is just as an order-of-operations simplification

    """

    segments = get_segments(graph.metagraph, metapath)

    mid = dwpc(

        graph,

        segments[1],

        damping=damping,

        dense_threshold=dense_threshold,

        dtype=dtype,

    )[2]

    row_names, cols, adj0 = dwpc(

        graph,

        segments[0],

        damping=damping,

        dense_threshold=dense_threshold,

        dtype=dtype,

    )

    rows, col_names, adj1 = dwpc(

        graph,

        segments[-1],

        damping=damping,

        dense_threshold=dense_threshold,

        dtype=dtype,

    )

    dwpc_matrix = remove_diag(adj0 @ mid @ adj1, dtype=dtype)

    return row_names, col_names, dwpc_matrix

def _dwpc_baab(graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64):

    """

    A function to handle metapath (segments) of the form BAAB.

    This function will handle arbitrary lengths of this repeated

    pattern. For example, ABCCBA, ABCDDCBA, etc. all work with this

    function. Random non-repeat inserts are supported. The metapath

    must start and end with a repeated node, though.

    Covers all variants of symmetrically repeated metanodes with

    support for random non-repeat metanode inserts at any point.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    damping : float

    dense_threshold : float (0 <= dense_threshold <= 1)

        sets the density threshold above which a sparse matrix will be

        converted to a dense automatically.

    dtype : dtype object

    Examples

    --------

    Acceptable metapaths forms include the following:

    B-A-A-B

    B-C-A-A-B

    B-C-A-D-A-E-B

    B-C-D-E-A-F-A-B

    C-B-A-A-B-D-E

    """

    # Segment the metapath

    segments = get_segments(graph.metagraph, metapath)

    # Start with the middle group (A-A or A-...-A in BAAB)

    for i, s in enumerate(segments):

        if s.source() == s.target():

            mid_seg = s

            mid_ind = i

    rows, cols, dwpc_mid = dwpc(

        graph, mid_seg, damping=damping, dense_threshold=dense_threshold, dtype=dtype

    )

    dwpc_mid = remove_diag(dwpc_mid, dtype=dtype)

    # Get two indices for the segments ahead of and behind the middle region

    head_ind = mid_ind

    tail_ind = mid_ind

    while head_ind > 0 or tail_ind < len(segments):

        head_ind -= 1

        tail_ind += 1

        head = segments[head_ind] if head_ind >= 0 else None

        tail = segments[tail_ind] if tail_ind < len(segments) else None

        # Multiply on the head

        if head is not None:

            row_names, cols, dwpc_head = dwpc(

                graph,

                head,

                damping=damping,

                dense_threshold=dense_threshold,

                dtype=dtype,

            )

            dwpc_mid = dwpc_head @ dwpc_mid

        # Multiply on the tail

        if tail is not None:

            rows, col_names, dwpc_tail = dwpc(

                graph,

                tail,

                damping=damping,

                dense_threshold=dense_threshold,

                dtype=dtype,

            )

            dwpc_mid = dwpc_mid @ dwpc_tail

        # Remove the diagonal if the head and tail are repeats

        if head and tail:

            if head.source() == tail.target():

                dwpc_mid = remove_diag(dwpc_mid, dtype=dtype)

    return row_names, col_names, dwpc_mid

def _dwpc_baba(graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64):

    """

    Computes the degree-weighted path count for overlapping metanode

    repeats of the form B-A-B-A. Supports random inserts.

    Segment must start with B and end with A. AXBYAZB

    Also supports four-node repeats of a single node, including random,

    non-repeated inserts. For example, ABBBXBC, AAAA.

    """

    segments = get_segments(graph.metagraph, metapath)

    seg_axb = None

    for i, s in enumerate(segments[:-2]):

        if s.source() == segments[i + 2].source() and not seg_axb:

            seg_axb = s

            seg_bya = segments[i + 1]

            seg_azb = segments[i + 2]

            seg_cda = segments[0] if i == 1 else None

            seg_bed = segments[-1] if segments[-1] != seg_azb else None

    # Collect segment DWPC and corrections

    row_names, cols, axb = dwpc(

        graph, seg_axb, damping=damping, dense_threshold=dense_threshold, dtype=dtype

    )

    rows, cols, bya = dwpc(

        graph, seg_bya, damping=damping, dense_threshold=dense_threshold, dtype=dtype

    )

    rows, col_names, azb = dwpc(

        graph, seg_azb, damping=damping, dense_threshold=dense_threshold, dtype=dtype

    )

    correction_a = (

        numpy.diag((axb @ bya).diagonal()) @ azb

        if not sparse.issparse(axb)

        else sparse.diags((axb @ bya).diagonal()) @ azb

    )

    correction_b = (

        axb @ numpy.diag((bya @ azb).diagonal())

        if not sparse.issparse(bya)

        else axb @ sparse.diags((bya @ azb).diagonal())

    )

    correction_c = (

        axb * bya.T * azb

        if not sparse.issparse(bya)

        else (axb.multiply(bya.T)).multiply(azb)

    )

    # Apply the corrections

    dwpc_matrix = axb @ bya @ azb - correction_a - correction_b + correction_c

    if seg_axb.source == seg_azb.target:

        dwpc_matrix = remove_diag(dwpc_matrix)

    # Account for possible head and tail segments outside the BABA group

    if seg_cda is not None:

        row_names, cols, cda = dwpc(

            graph,

            seg_cda,

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = cda @ dwpc_matrix

    if seg_bed is not None:

        rows, col_names, bed = dwpc(

            graph,

            seg_bed,

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = dwpc_matrix @ bed

    return row_names, col_names, dwpc_matrix

def _dwpc_short_repeat(

    graph, metapath, damping=0.5, dense_threshold=0, dtype=numpy.float64

):

    """

    One metanode repeated 3 or fewer times (A-A-A), not (A-A-A-A)

    This can include other random inserts, so long as they are not

    repeats. Must start and end with the repeated node. Acceptable

    examples: (A-B-A-A), (A-B-A-C-D-E-F-A), (A-B-A-A), etc.

    """

    segments = get_segments(graph.metagraph, metapath)

    assert len(segments) <= 3

    # Account for different head and tail possibilities.

    head_segment = None

    tail_segment = None

    dwpc_matrix = None

    dwpc_tail = None

    # Label the segments as head, tail, and repeat

    for i, segment in enumerate(segments):

        if segment.source() == segment.target():

            repeat_segment = segment

        else:

            if i == 0:

                head_segment = segment

            else:

                tail_segment = segment

    # Calculate DWPC for the middle ("repeat") segment

    repeated_metanode = repeat_segment.source()

    index_of_repeats = [

        i for i, v in enumerate(repeat_segment.get_nodes()) if v == repeated_metanode

    ]

    for metaedge in repeat_segment[: index_of_repeats[1]]:

        rows, cols, adj = hetmatpy.matrix.metaedge_to_adjacency_matrix(

            graph, metaedge, dtype=dtype, dense_threshold=dense_threshold

        )

        adj = _degree_weight(adj, damping, dtype=dtype)

        if dwpc_matrix is None:

            row_names = rows

            dwpc_matrix = adj

        else:

            dwpc_matrix = dwpc_matrix @ adj

    dwpc_matrix = remove_diag(dwpc_matrix, dtype=dtype)

    # Extra correction for random metanodes in the repeat segment

    if len(index_of_repeats) == 3:

        for metaedge in repeat_segment[index_of_repeats[1] :]:

            rows, cols, adj = hetmatpy.matrix.metaedge_to_adjacency_matrix(

                graph, metaedge, dtype=dtype, dense_threshold=dense_threshold

            )

            adj = _degree_weight(adj, damping, dtype=dtype)

            if dwpc_tail is None:

                dwpc_tail = adj

            else:

                dwpc_tail = dwpc_tail @ adj

        dwpc_tail = remove_diag(dwpc_tail, dtype=dtype)

        dwpc_matrix = dwpc_matrix @ dwpc_tail

        dwpc_matrix = remove_diag(dwpc_matrix, dtype=dtype)

    col_names = cols

    if head_segment:

        row_names, cols, head_dwpc = dwpc(

            graph,

            head_segment,

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = head_dwpc @ dwpc_matrix

    if tail_segment:

        rows, col_names, tail_dwpc = dwpc(

            graph,

            tail_segment,

            damping=damping,

            dense_threshold=dense_threshold,

            dtype=dtype,

        )

        dwpc_matrix = dwpc_matrix @ tail_dwpc

    return row_names, col_names, dwpc_matrix

def _node_to_children(

    graph, metapath, node, metapath_index, damping=0, history=None, dtype=numpy.float64

):

    """

    Returns a history adjusted list of child nodes. Used in _dwpc_general_case.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    node : numpy.ndarray

    metapath_index : int

    damping : float

    history : numpy.ndarray

    dtype : dtype object

    Returns

    -------

    dict

        List of child nodes and a single numpy.ndarray of the newly

        updated history vector.

    """

    metaedge = metapath[metapath_index]

    metanodes = list(metapath.get_nodes())

    freq = collections.Counter(metanodes)

    repeated = {i for i in freq.keys() if freq[i] > 1}

    if history is None:

        history = {

            i.target: numpy.ones(

                len(hetmatpy.matrix.metaedge_to_adjacency_matrix(graph, i)[1]),

                dtype=dtype,

            )

            for i in metapath

            if i.target in repeated

        }

    history = history.copy()

    if metaedge.source in history:

        history[metaedge.source] -= numpy.array(node != 0, dtype=dtype)

    rows, cols, adj = hetmatpy.matrix.metaedge_to_adjacency_matrix(

        graph, metaedge, dtype=dtype

    )

    adj = _degree_weight(adj, damping, dtype=dtype)

    vector = node @ adj

    if metaedge.target in history:

        vector *= history[metaedge.target]

    children = [i for i in numpy.diag(vector) if i.any()]

    return {"children": children, "history": history, "next_index": metapath_index + 1}

def _dwpc_general_case(graph, metapath, damping=0, dtype=numpy.float64):

    """

    A slow but general function to compute the degree-weighted

    path count. Works by splitting the metapath at junctions

    where one node is joined to multiple nodes over a metaedge.

    Parameters

    ----------

    graph : hetnetpy.hetnet.Graph

    metapath : hetnetpy.hetnet.MetaPath

    damping : float

    dtype : dtype object

    """

    dwpc_step = functools.partial(

        _node_to_children, graph=graph, metapath=metapath, damping=damping, dtype=dtype

    )

    start_nodes, cols, adj = hetmatpy.matrix.metaedge_to_adjacency_matrix(

        graph, metapath[0]

    )

    rows, fin_nodes, adj = hetmatpy.matrix.metaedge_to_adjacency_matrix(

        graph, metapath[-1]

    )

    number_start = len(start_nodes)

    number_end = len(fin_nodes)

    dwpc_matrix = []

    if len(metapath) > 1:

        for i in range(number_start):

            search = numpy.zeros(number_start, dtype=dtype)

            search[i] = 1

            step1 = [dwpc_step(node=search, metapath_index=0, history=None)]

            k = 1

            while k < len(metapath):

                k += 1

                step2 = []

                for group in step1:

                    for child in group["children"]:

                        hist = copy.deepcopy(group["history"])

                        out = dwpc_step(

                            node=child, metapath_index=group["next_index"], history=hist

                        )

                        if out["children"]:

                            step2.append(out)

                    step1 = step2

            final_children = [group for group in step2 if group["children"] != []]

            end_nodes = sum(

                child for group in final_children for child in group["children"]

            )

            if type(end_nodes) not in (list, numpy.ndarray):

                end_nodes = numpy.zeros(number_end)

            dwpc_matrix.append(end_nodes)

    else:

        dwpc_matrix = _degree_weight(adj, damping=damping, dtype=dtype)

    dwpc_matrix = numpy.array(dwpc_matrix, dtype=dtype)

    return start_nodes, fin_nodes, dwpc_matrix

# Default DWWC method to use, when not specified

default_dwwc_method = dwwc_chain

Functions

categorize

def categorize(
    metapath
)

Returns the classification of a given metapath as one of a set of metapath types which we approach differently.

Parameters:

metapath : hetnetpy.hetnet.MetaPath

Returns:

classification : string
    One of ['no_repeats', 'disjoint', 'short_repeat', 'long_repeat', 'BAAB',
    'BABA', 'repeat_around', 'interior_complete_group', 'disjoint_groups',
    'other']
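
A minimal usage sketch (hedged: it assumes a hetnetpy.hetnet.MetaGraph named metagraph is already loaded and that metapaths are built with MetaGraph.metapath_from_abbrev; adapt to however your metagraph is constructed):

    from hetmatpy.degree_weight import categorize

    # `metagraph` is assumed to be a previously loaded hetnetpy.hetnet.MetaGraph.
    categorize(metagraph.metapath_from_abbrev("GbCtDlA"))  # 'no_repeats', per the Examples
    categorize(metagraph.metapath_from_abbrev("GbCrCbG"))  # 'BAAB'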

default_dwwc_method

def default_dwwc_method(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    dtype=numpy.float64
)

Alias of dwwc_chain, the module-level default DWWC method (default_dwwc_method = dwwc_chain). Uses optimal matrix chain multiplication as in numpy.multi_dot, but allows for sparse matrices. Uses ordering modified from numpy.linalg.linalg._multi_dot (https://git.io/vh31f), which is released under a 3-Clause BSD License (https://git.io/vhCDC).


dwpc

def dwpc(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    approx_ok=False,
    dtype=numpy.float64,
    dwwc_method=None
)

A unified function to compute the degree-weighted path count.

This function will call get_segments, then the appropriate specialized (or generalized) DWPC function.

Parameters:

graph : hetnetpy.hetnet.Graph
metapath : hetnetpy.hetnet.MetaPath
damping : float, default 0.5
dense_threshold : float (0 <= dense_threshold <= 1), default 0
    sets the density threshold above which a sparse matrix will be
    converted to a dense automatically.
approx_ok : bool, default False
    if True, uses an approximation to DWPC. If False, dwpc will call
    _dwpc_general_case and give a warning on metapaths that are
    categorized 'other' or 'long_repeat'.
dtype : dtype object, default numpy.float64
    numpy.float32 or numpy.float64. At present, numpy.float16 fails when
    using sparse matrices, due to a bug in scipy.sparse.
dwwc_method : function, default None
    dwwc method to use for computing DWWCs. If set to None, use the
    module-level default (default_dwwc_method).

Returns:

numpy.ndarray
    row labels
numpy.ndarray
    column labels
numpy.ndarray or scipy.sparse.csc_matrix
    the DWPC matrix
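
A hedged usage sketch (it assumes graph is a previously loaded hetnetpy.hetnet.Graph and that the metapath abbreviation below exists in its metagraph; both are illustrative, not part of this module):

    from hetmatpy.degree_weight import dwpc

    metapath = graph.metagraph.metapath_from_abbrev("CbGaD")  # assumed abbreviation
    rows, cols, dwpc_matrix = dwpc(graph, metapath, damping=0.5, dense_threshold=0.7)
    # rows and cols are node-label arrays; dwpc_matrix is dense or sparse
    # depending on its density relative to dense_threshold.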

dwwc

def dwwc(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    dtype=numpy.float64,
    dwwc_method=None
)

Compute the degree-weighted walk count (DWWC) in which nodes can be repeated within a path.

Parameters:

graph : hetnetpy.hetnet.Graph
metapath : hetnetpy.hetnet.MetaPath
damping : float, default 0.5
dense_threshold : float (0 <= dense_threshold <= 1), default 0
    sets the density threshold at which a sparse matrix will be
    converted to a dense automatically.
dtype : dtype object, default numpy.float64
dwwc_method : function, default None
    dwwc method to use for computing DWWCs. If set to None, use the
    module-level default (default_dwwc_method).
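
A hedged sketch of selecting the underlying implementation via dwwc_method (graph and metapath assumed loaded as in the dwpc example above):

    from hetmatpy.degree_weight import dwwc, dwwc_sequential

    # dwwc_method=None falls back to the module-level default (default_dwwc_method),
    # which is dwwc_chain.
    rows, cols, walks = dwwc(graph, metapath, damping=0.5)

    # Or pick an implementation explicitly:
    rows, cols, walks = dwwc(graph, metapath, damping=0.5, dwwc_method=dwwc_sequential)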

dwwc_chain

def dwwc_chain(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    dtype=numpy.float64
)

Uses optimal matrix chain multiplication as in numpy.multi_dot, but allows for sparse matrices. Uses ordering modified from numpy.linalg.linalg._multi_dot (https://git.io/vh31f), which is released under a 3-Clause BSD License (https://git.io/vhCDC).
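
To see why the multiplication order matters, here is a small self-contained cost comparison (a sketch with made-up node counts; dwwc_chain chooses the cheaper parenthesization automatically via _dimensions_to_ordering):

    # Node counts along a 3-edge metapath: A is 1000x10, B is 10x1000, C is 1000x10.
    dims = [1000, 10, 1000, 10]
    cost_left = dims[0] * dims[1] * dims[2] + dims[0] * dims[2] * dims[3]   # (A @ B) @ C
    cost_right = dims[1] * dims[2] * dims[3] + dims[0] * dims[1] * dims[3]  # A @ (B @ C)
    print(cost_left, cost_right)  # 20000000 vs 200000 scalar multiplications, a 100x difference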


dwwc_recursive

def dwwc_recursive(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    dtype=numpy.float64
)

Recursive DWWC implementation to take better advantage of caching.
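
A hedged sketch of how the recursion interacts with caching (the metapath abbreviation is illustrative; graph is assumed loaded as above). Because dwwc_recursive computes the suffix metapath[1:] through the cached dwwc entry point, suffix results can be reused across metapaths that share them:

    from hetmatpy.degree_weight import dwwc, dwwc_recursive

    metapath = graph.metagraph.metapath_from_abbrev("CbGiGaD")  # hypothetical abbreviation
    rows, cols, walks = dwwc(graph, metapath, dwwc_method=dwwc_recursive)
    # Suffix DWWCs (e.g. for GiGaD and GaD) pass through the @path_count_cache-decorated
    # dwwc call, so a later metapath ending in the same suffix can reuse the cached result.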


dwwc_sequential

def dwwc_sequential(
    graph,
    metapath,
    damping=0.5,
    dense_threshold=0,
    dtype=numpy.float64
)

Compute the degree-weighted walk count (DWWC) in which nodes can be repeated within a path.

Parameters:

graph : hetnetpy.hetnet.Graph
metapath : hetnetpy.hetnet.MetaPath
damping : float, default 0.5
dense_threshold : float (0 <= dense_threshold <= 1), default 0
    sets the density threshold at which a sparse matrix will be
    converted to a dense automatically.
dtype : dtype object, default numpy.float64
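
The product being computed, written out as a sketch (assuming hetmatpy.matrix.normalize divides by the degree sums raised to the damping exponent, as its use in _degree_weight above suggests): for a metapath with metaedge adjacency matrices A_1, ..., A_n and damping exponent w,

    [W_i]_{uv} = \frac{[A_i]_{uv}}{d_u^{\,w} \, d_v^{\,w}},
    \qquad
    \mathrm{DWWC} = W_1 W_2 \cdots W_n,

where d_u and d_v are the row (source) and column (target) degree sums of A_i. dwwc_sequential forms this product left to right, one metaedge at a time.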

get_all_segments

def get_all_segments(
    metagraph,
    metapath
)

Return all subsegments of a given metapath, including those segments that appear only after early splits.

Parameters:

metagraph : hetnetpy.hetnet.MetaGraph
metapath : hetnetpy.hetnet.MetaPath

Returns:

list

Example:

>>> get_all_segments(metagraph, CrCbGaDrDaG)
[CrC, CbG, GaDrDaG, GaD, DrD, DaG]

get_segments

def get_segments(
    metagraph,
    metapath
)

Split a metapath into segments of recognized groups and non-repeated nodes. Groups include BAAB, BABA, disjoint short- and long-repeats. Returns an error for categorization 'other'.

Parameters:

metagraph : hetnetpy.hetnet.MetaGraph
metapath : hetnetpy.hetnet.MetaPath

Returns:

list
    list of metapaths. If the metapath is not segmentable or is already
    fully simplified (e.g. GiGaDaG), then the list will have only one
    element.

Examples:

'CbGaDaGaD' -> ['CbG', 'GaD', 'GaG', 'GaD']
'GbCpDaGaD' -> ['GbCpD', 'DaG', 'GaD']
'CrCbGiGaDrD' -> ['CrC', 'CbG', 'GiG', 'GaD', 'DrD']
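
A hedged usage sketch (metagraph is assumed to be a hetnetpy.hetnet.MetaGraph; the abbreviation mirrors the Examples above):

    from hetmatpy.degree_weight import get_segments

    metapath = metagraph.metapath_from_abbrev("CrCbGiGaDrD")
    segments = get_segments(metagraph, metapath)
    # Expected, per the Examples above: [CrC, CbG, GiG, GaD, DrD]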

order_segments

def order_segments(
    metagraph,
    metapaths,
    store_inverses=False
)

Gives the frequencies of metapath segments that occur when computing DWPC.

In DWPC computation, metapaths are split a number of times for simpler computation. This function finds how frequently each segment would be used when computing DWPC for all given metapaths, which enables targeted caching of the most frequently used segments.

Parameters:

metagraph : hetnetpy.hetnet.MetaGraph
metapaths : list
    list of hetnetpy.hetnet.MetaPath objects
store_inverses : bool, default False
    Whether or not to include both forward and backward directions of segments.
    For example, if False: [CbG, GbC] -> [CbG, CbG], else no change.

Returns:

collections.Counter
    Number of times each metapath segment appears when getting all segments.
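
A hedged sketch of using the counts to prioritize caching (metapaths is assumed to be a list of hetnetpy.hetnet.MetaPath objects obtained elsewhere, e.g. enumerated from the metagraph):

    from hetmatpy.degree_weight import order_segments

    segment_counts = order_segments(metagraph, metapaths, store_inverses=False)
    # The most frequently used segments are the best candidates for caching:
    for segment, count in segment_counts.most_common(5):
        print(count, segment)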

remove_diag

def remove_diag(
    mat,
    dtype=numpy.float64
)

Set the main diagonal of a square matrix to zeros.
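
A self-contained sketch of the dense branch (arbitrary values):

    import numpy

    from hetmatpy.degree_weight import remove_diag

    mat = numpy.array([[1.0, 2.0], [3.0, 4.0]])
    zeroed = remove_diag(mat)
    # zeroed == array([[0., 2.],
    #                  [3., 0.]])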
