# Source code for ensign.query_decomp
# ENSIGN rights
"""Return all components containing the query label in the given modes.

For every searched mode whose label map contains <query label>, the
components in which that label has a positive score are returned.
"""
import os
import numpy as np
import pandas as pd
import ensign.cp_decomp as cpd
import ensign.ensign_io.ensign_logging as ensign_logging
log = ensign_logging.get_logger()
def query_decomp(factors, labels, modes_to_search, query_label):
    """Find the components in which the query label has a positive score.

    For every tensor, each requested mode whose label map contains
    ``query_label`` is inspected: the row of that mode's factor matrix
    belonging to the label is read, and every component (column) whose
    entry in that row is strictly greater than zero is reported.

    A single decomposition may be passed without the outer list: when
    ``factors[0]`` is an ndarray and ``modes_to_search`` is a flat list of
    ints, the three arguments are wrapped so they are treated as one tensor.

    Parameters
    ----------
    factors : list of lists of ndarrays
        Each item in the top-level list should contain the factor matrices of
        one tensor. Each set of factor matrices is a list of ndarrays.
    labels : list of lists of strings
        Each item in the top-level list should contain the labels of one
        tensor. Each set of labels is a list of lists of strings. Label
        sequences that are not lists (for example ndarrays) are converted to
        lists before the lookup.
    modes_to_search : list of list of ints
        Each sublist should contain the mode indices to be searched for the
        tensor at that index in the outermost list. Mode indices that do not
        exist for a tensor are skipped with a logged warning.
    query_label : str
        Query label with which to query the decomposition.

    Returns
    -------
    result : list of pandas DataFrames
        DataFrames are in the same order as the input arguments.
        DataFrames have columns 'Score', 'Mode', and 'Component'. Score is
        the positive value of the query label. Component is the corresponding
        component in mode Mode. Rows are ordered by mode, then by component.
        A tensor in which the label is not found yields an empty DataFrame.
        An empty ``factors`` yields an empty list.

    Raises
    ------
    IOError
        If a tensor has a different number of factor matrices and label maps.
    """
    # Nothing to query: avoid indexing into an empty argument below.
    if len(factors) == 0:
        return []

    # Convenience form: a single decomposition passed without the outer list.
    single_tensor = isinstance(factors[0], np.ndarray) and (
        len(modes_to_search) == 0 or not isinstance(modes_to_search[0], list)
    )
    if single_tensor:
        factors, labels, modes_to_search = [factors], [labels], [modes_to_search]

    res = []
    for factor_matrices, label_maps, mode_ids in zip(factors, labels, modes_to_search):
        if len(factor_matrices) != len(label_maps):
            msg = 'The number of factor matrices (decomp_mode_<x>.txt files) does not equal the number of label maps (map_mode_<x>.txt files).'
            log.error(msg)
            raise IOError(msg)
        order = len(factor_matrices)

        # Validate mode ids; unknown modes are reported and then ignored.
        # A set is used so a mode listed twice is only searched once.
        valid_modes = set()
        for mode_id in mode_ids:
            if mode_id in range(order):
                valid_modes.add(mode_id)
            else:
                msg = "{} factor matrices (decomp_mode_<x>.txt files) and label maps (map_mode_<x>.txt files) were passed, mode {} does not exist.".format(order, mode_id)
                log.warning(msg)

        # Collect [score, mode, component] rows. Sorting the modes makes the
        # row order deterministic.
        query_results = []
        for mode in sorted(valid_modes):
            label_list = label_maps[mode]
            if not isinstance(label_list, list):
                # e.g. an ndarray of labels, which has no .index() method
                label_list = list(label_list)
            try:
                # Row index of the factor matrix that corresponds to the label
                label_idx = label_list.index(query_label)
            except ValueError:
                # The label does not occur in this mode
                continue

            # Row of the factor matrix associated with query_label
            factor_row = factor_matrices[mode][label_idx, :]
            # NOTE(review): the documentation speaks of "non-zero" scores but
            # this filter keeps strictly positive ones only - confirm whether
            # negative scores should be reported as well.
            component_idxs = np.where(factor_row > 0)[0]
            query_results.extend(
                [factor_row[comp_id], mode, comp_id] for comp_id in component_idxs
            )

        res.append(pd.DataFrame(query_results, columns=['Score', 'Mode', 'Component']))
    return res