Source code for textacy.viz.termite

import numpy as np

from .. import errors
from ..utils import to_collection

try:
    import matplotlib.pyplot as plt
    import matplotlib.transforms

except ImportError:
    pass


RC_PARAMS = {
    "axes.axisbelow": True,
    "axes.edgecolor": ".8",
    "axes.facecolor": "white",
    "axes.grid": True,
    "axes.labelcolor": ".15",
    "axes.linewidth": 1.0,
    "figure.facecolor": "white",
    "font.family": ["sans-serif"],
    "font.sans-serif": ["Arial", "Liberation Sans", "sans-serif"],
    "grid.color": ".8",
    "grid.linestyle": "-",
    "image.cmap": "Greys",
    "legend.frameon": False,
    "legend.numpoints": 1,
    "legend.scatterpoints": 1,
    "lines.solid_capstyle": "round",
    "text.color": ".15",
    "xtick.color": ".15",
    "xtick.direction": "out",
    "xtick.major.size": 0.0,
    "xtick.minor.size": 0.0,
    "ytick.color": ".15",
    "ytick.direction": "out",
    "ytick.major.size": 0.0,
    "ytick.minor.size": 0.0,
}

COLOR_PAIRS = (
    (
        (0.65098041296005249, 0.80784314870834351, 0.89019608497619629),
        (0.12572087695201239, 0.47323337360924367, 0.707327968232772),
    ),
    (
        (0.68899655751153521, 0.8681737867056154, 0.54376011946622071),
        (0.21171857311445125, 0.63326415104024547, 0.1812226118410335),
    ),
    (
        (0.98320646005518297, 0.5980161709820524, 0.59423301088459368),
        (0.89059593116535862, 0.10449827132271793, 0.11108035462744099),
    ),
    (
        (0.99175701702342312, 0.74648213716698619, 0.43401768935077328),
        (0.99990772780250103, 0.50099192647372981, 0.0051211073118098693),
    ),
    (
        (0.78329874347238004, 0.68724338552531095, 0.8336793640080622),
        (0.42485198495434734, 0.2511495584950722, 0.60386007743723258),
    ),
    (
        (0.99760092286502611, 0.99489427150464516, 0.5965244373854468),
        (0.69411766529083252, 0.3490196168422699, 0.15686275064945221),
    ),
)


[docs]def draw_termite_plot( values_mat, col_labels, row_labels, *, highlight_cols=None, highlight_colors=None, save=False, rc_params=None, ): """ Make a "termite" plot, typically used for assessing topic models with a tabular layout that promotes comparison of terms both within and across topics. Args: values_mat (:class:`np.ndarray` or matrix): matrix of values with shape (# row labels, # col labels) used to size the dots on the grid col_labels (seq[str]): labels used to identify x-axis ticks on the grid row_labels(seq[str]): labels used to identify y-axis ticks on the grid highlight_cols (int or seq[int], optional): indices for columns to visually highlight in the plot with contrasting colors highlight_colors (tuple of 2-tuples): each 2-tuple corresponds to a pair of (light/dark) matplotlib-friendly colors used to highlight a single column; if not specified (default), a good set of 6 pairs are used save (str, optional): give the full /path/to/fname on disk to save figure rc_params (dict, optional): allow passing parameters to rc_context in matplotlib.plyplot, details in https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.rc_context.html Returns: :obj:`matplotlib.axes.Axes.axis`: Axis on which termite plot is plotted. Raises: ValueError: if more columns are selected for highlighting than colors or if any of the inputs' dimensions don't match References: Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite: Visualization techniques for assessing textual topic models." Proceedings of the International Working Conference on Advanced Visual Interfaces. ACM, 2012. See Also: :meth:`TopicModel.termite_plot() <textacy.tm.topic_model.TopicModel.termite_plot>` """ try: plt except NameError: raise ImportError( "`matplotlib` is not installed, so `textacy.viz` won't work; " "install it individually via `$ pip install matplotlib`, or " "along with textacy via `pip install textacy[viz]`." ) n_rows, n_cols = values_mat.shape max_val = np.max(values_mat) if n_rows != len(row_labels): raise ValueError( "values_mat and row_labels dimensions don't match: " f"{n_rows} vs. {len(row_labels)}" ) if n_cols != len(col_labels): raise ValueError( "values_mat and col_labels dimensions don't match: " f"{n_cols} vs. {len(col_labels)}" ) if highlight_colors is None: highlight_colors = COLOR_PAIRS if highlight_cols is not None: if isinstance(highlight_cols, int): highlight_cols = (highlight_cols,) elif len(highlight_cols) > len(highlight_colors): raise ValueError( f"no more than {len(highlight_colors)} columns " "may be highlighted at once" ) highlight_colors = {hc: COLOR_PAIRS[i] for i, hc in enumerate(highlight_cols)} _rc_params = RC_PARAMS.copy() if rc_params: _rc_params.update(rc_params) with plt.rc_context(RC_PARAMS): fig, ax = plt.subplots(figsize=(pow(n_cols, 0.8), pow(n_rows, 0.66))) _ = ax.set_yticks(range(n_rows)) yticklabels = ax.set_yticklabels(row_labels, fontsize=14, color="gray") if highlight_cols is not None: for i, ticklabel in enumerate(yticklabels): max_tick_val = max(values_mat[i, hc] for hc in range(n_cols)) for hc in highlight_cols: if max_tick_val > 0 and values_mat[i, hc] == max_tick_val: ticklabel.set_color(highlight_colors[hc][1]) ax.get_xaxis().set_ticks_position("top") _ = ax.set_xticks(range(n_cols)) xticklabels = ax.set_xticklabels( col_labels, fontsize=14, color="gray", rotation=-60, ha="right" ) # Create offset transform by 5 points in x direction dx = 10 / 72 dy = 0 offset = matplotlib.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans) for label in ax.xaxis.get_majorticklabels(): label.set_transform(label.get_transform() + offset) if highlight_cols is not None: gridlines = ax.get_xgridlines() for i, ticklabel in enumerate(xticklabels): if i in highlight_cols: ticklabel.set_color(highlight_colors[i][1]) gridlines[i].set_color(highlight_colors[i][0]) gridlines[i].set_alpha(0.5) for col_ind in range(n_cols): if highlight_cols is not None and col_ind in highlight_cols: ax.scatter( [col_ind for _ in range(n_rows)], [i for i in range(n_rows)], s=600 * (values_mat[:, col_ind] / max_val), alpha=0.5, linewidth=1, color=highlight_colors[col_ind][0], edgecolor=highlight_colors[col_ind][1], ) else: ax.scatter( [col_ind for _ in range(n_rows)], [i for i in range(n_rows)], s=600 * (values_mat[:, col_ind] / max_val), alpha=0.5, linewidth=1, color="lightgray", edgecolor="gray", ) _ = ax.set_xlim(left=-1, right=n_cols) _ = ax.set_ylim(bottom=-1, top=n_rows) ax.invert_yaxis() # otherwise, values/labels go from bottom to top if save: fig.savefig(save, bbox_inches="tight", dpi=100) return ax
[docs]def termite_df_plot( components, *, highlight_topics=None, n_terms=25, rank_terms_by="max", sort_terms_by="seriation", save=False, rc_params=None, ): """ Make a "termite" plot for assessing topic models using a tabular layout to promote comparison of terms both within and across topics. Args: components (:class:`pandas.DataFrame` or sparse matrix): corpus represented as a term-topic matrix with shape (n_terms, n_topics); should have terms as index and topics as column names topics (int or Sequence[int]): topic(s) to include in termite plot; if -1, all topics are included highlight_topics (str or Sequence[str]): names for up to 6 topics to visually highlight in the plot with contrasting colors n_terms (int): number of top terms to include in termite plot rank_terms_by ({'max', 'mean', 'var'}): argument to dataframe `agg` function, used to rank terms; the top-ranked ``n_terms`` are included in the plot sort_terms_by ({'seriation', 'weight', 'index', 'alphabetical'}): method used to vertically sort the selected top ``n_terms`` terms; the default ("seriation") groups similar terms together, which facilitates cross-topic assessment save (str): give the full /path/to/fname on disk to save figure rc_params (dict, optional): allow passing parameters to rc_context in matplotlib.plyplot, details in https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.rc_context.html Returns: ``matplotlib.axes.Axes.axis``: Axis on which termite plot is plotted. Raises: ValueError: if more than 6 topics are selected for highlighting, or an invalid value is passed for the sort_topics_by, rank_terms_by, and/or sort_terms_by params References: - Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite: Visualization techniques for assessing textual topic models." Proceedings of the International Working Conference on Advanced Visual Interfaces. ACM, 2012. - Fajwel Fogel, Alexandre d’Aspremont, and Milan Vojnovic. 2016. Spectral ranking using seriation. J. Mach. Learn. Res. 17, 1 (January 2016), 3013–3057. See Also: :func:`viz.termite_plot <textacy.viz.termite.termite_plot>` TODO: `rank_terms_by` other metrics, e.g. topic salience or relevance """ highlight_topics = to_collection(highlight_topics, str, list) if len(highlight_topics) > 6: raise ValueError("no more than 6 topics may be highlighted at once") # get column index of any topics to highlight in termite plot if highlight_topics is not None: highlight_cols = tuple(components.columns.get_loc(c) for c in highlight_topics) else: highlight_cols = None component_filter = components.loc[ components.agg(rank_terms_by, axis=1) .sort_values(ascending=False) .iloc[:n_terms] .index ] # get top term indices in sorted order if sort_terms_by == "weight": pass # already done elif sort_terms_by == "alphabetical": component_filter = component_filter.sort_index() elif sort_terms_by == "seriation": # calculate similarity matrix similarity = ( component_filter @ component_filter.T.pipe(lambda df: df - df.min().min()) ).values # compute Laplacian matrice and its 2nd eigenvector L = np.diag(similarity.sum(axis=1)) - similarity D, V = np.linalg.eigh(L) D = D[np.argsort(D)] V = V[:, np.argsort(D)] fiedler = V[:, 1] # get permutation corresponding to sorting the 2nd eigenvector component_filter = component_filter.reindex( index=[component_filter.index[i] for i in np.argsort(fiedler)], ) else: raise ValueError( errors.value_invalid_msg( "sort_terms_by", sort_terms_by, {"weight", "alphabetical", "seriation"}, ) ) # get topic and term labels topic_labels = component_filter.columns term_labels = component_filter.index # get topic-term weights to size dots term_topic_weights = component_filter.values return draw_termite_plot( term_topic_weights, topic_labels, term_labels, highlight_cols=highlight_cols, save=save, rc_params=rc_params, )