Source code for ensign.query_decomp

# ENSIGN rights
"""Return all components containing the query label in the given mode.
Up to k components are printed where <query label> has a non-zero score.
"""

import os

import numpy as np
import pandas as pd

import ensign.cp_decomp as cpd
import ensign.ensign_io.ensign_logging as ensign_logging

log = ensign_logging.get_logger()

def query_decomp(factors, labels, modes_to_search, query_label):
    """Find the components in which the query label has a positive score.

    For every tensor, each requested mode whose label map contains
    ``query_label`` is searched: the factor-matrix row that belongs to the
    label is scanned and every component (column) holding a strictly
    positive score is reported.

    A single tensor may be passed without the outer list, i.e. ``factors``
    as a flat list of ndarrays together with a flat list of mode indices.
    The result is still a (one-element) list in that case.

    Parameters
    ----------
    factors : list of lists of ndarrays
        Each item in the top-level list should contain the factor matrices
        of one tensor. Each set of factor matrices is a list of ndarrays,
        indexed as ``matrix[label_index, component_index]``.
    labels : list of lists of strings
        Each item in the top-level list should contain the labels of one
        tensor. Each set of labels is a list of lists of strings (one list
        of labels per mode). Any sequence of labels is accepted.
    modes_to_search : list of list of ints
        Each sublist should contain the mode indices to be searched for
        the tensor at that index in the outermost list. Mode indices that
        do not exist for a tensor are skipped with a logged warning.
    query_label : str
        Query label with which to query the decomposition.

    Returns
    -------
    result : list of pandas DataFrames
        DataFrames are in the same order as the input arguments.
        DataFrames have columns 'Score', 'Mode', and 'Component'. Score is
        the positive value found for the query label, Component is the
        index of the component holding it, and Mode is the mode searched.
        Rows are ordered by ascending mode, then ascending component.
        An empty ``factors`` yields an empty list.

    Raises
    ------
    IOError
        If a tensor has a different number of factor matrices and label
        maps.
    """
    if len(factors) == 0:
        return []

    # A flat list of ndarrays combined with a flat list of mode ids means a
    # single tensor was passed; wrap everything so the loop below is uniform.
    modes_are_flat = (len(modes_to_search) == 0
                      or not isinstance(modes_to_search[0], list))
    if isinstance(factors[0], np.ndarray) and modes_are_flat:
        factors, labels, modes_to_search = [factors], [labels], [modes_to_search]

    res = []
    for factor_matrices, label_maps, mode_ids in zip(factors, labels, modes_to_search):
        order = len(factor_matrices)
        if order != len(label_maps):
            msg = 'The number of factor matrices (decomp_mode_<x>.txt files) does not equal the number of label maps (map_mode_<x>.txt files).'
            log.error(msg)
            raise IOError(msg)

        # Validate mode ids: unknown modes are reported and then ignored.
        valid_modes = []
        for mode_id in mode_ids:
            if mode_id in range(order):
                valid_modes.append(mode_id)
            else:
                msg = "{} factor matrices (decomp_mode_<x>.txt files) and label maps (map_mode_<x>.txt files) were passed, mode {} does not exist.".format(order, mode_id)
                log.warning(msg)

        query_results = []
        # De-duplicate and sort so the row order of the result is deterministic.
        for mode in sorted(set(valid_modes)):
            # Row index of the factor matrix that corresponds to the label.
            # Only a missing label is expected here; any other error is a
            # real problem and must not be hidden.
            try:
                label_idx = list(label_maps[mode]).index(query_label)
            except ValueError:
                continue  # the query label does not occur in this mode

            # Row of the factor matrix associated with query_label.
            factor_row = factor_matrices[mode][label_idx, :]

            # NOTE(review): only strictly positive scores are kept, although
            # the original description says "non-zero" - confirm whether
            # negative scores should be reported as well.
            for comp_id in np.flatnonzero(factor_row > 0):
                query_results.append([factor_row[comp_id], mode, comp_id])

        res.append(pd.DataFrame(query_results, columns=['Score', 'Mode', 'Component']))

    return res