# Source code for radvel.plot.mcmc_plots

from __future__ import annotations

from types import ModuleType
import numpy as np
import corner
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import pyplot as pl
from matplotlib import rcParams
from pandas import DataFrame

from radvel import plot
from radvel.posterior import Posterior

"""
Module for plotting results of MCMC analysis, including:
    - trend plot
    - autocorrelation plot
    - corner plot of fitted parameters
    - corner plot of derived parameters
"""


class TrendPlot(object):
    """
    Class to handle the creation of a trend plot to show the evolution of
    the MCMC as a function of step number.

    Args:
        post (radvel.Posterior): Radvel Posterior object
        chains (DataFrame): MCMC chains output by radvel.mcmc
        nwalkers (int): number of walkers used in this particular MCMC run
        nensembles (int): length of the second axis that each flattened
            chain is reshaped to in ``plot`` (presumably the number of
            ensembles in this MCMC run -- confirm against radvel.mcmc)
        outfile (string [optional]): name of output multi-page PDF file.
            It is handed directly to ``PdfPages`` by ``plot``, so it has
            to be set before ``plot`` is called.
    """

    def __init__(self, post: Posterior, chains: DataFrame, nwalkers: int,
                 nensembles: int, outfile: str | None = None):
        self.chains = chains
        self.outfile = outfile
        self.nwalkers = nwalkers
        self.nensembles = nensembles

        # Only the parameters that are allowed to vary get a page.
        self.labels = sorted(
            [k for k in post.params.keys() if post.params[k].vary]
        )

        # Build the LaTeX label lookup once rather than once per label;
        # fall back to the raw parameter name when no label is known.
        tex_lookup = post.params.tex_labels()
        self.texlabels = [tex_lookup.get(name, name) for name in self.labels]

        # One colour per walker, sampled evenly from the package colormap.
        self.colors = [plot.cmap(x) for x in np.linspace(0.05, 0.95, nwalkers)]

    def plot(self) -> None:
        """
        Make and save the trend plot as PDF

        One page is written per varying parameter. Each page shows every
        (walker, ensemble) chain as a function of step number, coloured
        by walker.
        """
        with PdfPages(self.outfile) as pdf:
            for param, tex in zip(self.labels, self.texlabels):
                flatchain = self.chains[param].values
                wchain = flatchain.reshape(
                    (self.nwalkers, self.nensembles, -1)
                )

                pl.figure(figsize=(18, 10))
                try:
                    for w in range(self.nwalkers):
                        for e in range(self.nensembles):
                            pl.plot(
                                wchain[w][e], '.', rasterized=True,
                                color=self.colors[w], markersize=4
                            )

                    pl.xlim(0, wchain.shape[2])
                    pl.xlabel('Step Number')

                    # Fall back to the plain parameter name if the LaTeX
                    # label is rejected with a ValueError.
                    try:
                        pl.ylabel(tex)
                    except ValueError:
                        pl.ylabel(param)

                    pl.gca().set_rasterized(True)
                    pdf.savefig()
                finally:
                    # Always release the figure, even if plotting or
                    # saving this page failed.
                    pl.close()

        print("Trend plot saved to %s" % self.outfile)
class AutoPlot(object):
    """
    Class to handle the creation of an autocorrelation time plot from
    output autocorrelation times.

    Args:
        auto (DataFrame): Autocorrelation times output by radvel.mcmc.
            The columns read here are 'autosamples', 'automin',
            'automean', 'automax' and 'factor'.
        saveplot (str, optional): Name of output file, will show as
            interactive matplotlib window if not defined.
    """

    def __init__(self, auto: DataFrame, saveplot: str | None = None):
        self.saveplot = saveplot
        self.auto = auto

    def plot(self) -> None:
        """
        Make and either save or display the autocorrelation plot
        """
        figure = pl.figure(figsize=(6, 4))

        samples = self.auto['autosamples']

        # (column, colour, legend entry) for the three scatter series.
        scatter_series = (
            ('automin', 'blue', 'Minimum Autocorrelation Time'),
            ('automean', 'black', 'Mean Autocorrelation Time'),
            ('automax', 'red', 'Maximum Autocorrelation Time'),
        )
        for column, colour, legend_text in scatter_series:
            pl.scatter(samples, self.auto[column], color=colour,
                       label=legend_text)

        # Dotted reference line: samples divided by the first factor entry.
        first_factor = self.auto['factor'][0]
        pl.plot(
            samples, samples / first_factor,
            linestyle=':', color='gray',
            label='Autocorrelation Factor Criterion (N/{})'.format(
                first_factor)
        )

        # Only constrain the x axis when it would have non-zero width.
        x_low = samples.min()
        x_high = samples.max()
        if x_low != x_high:
            pl.xlim(x_low, x_high)

        # Upper y limit is whichever is larger: the criterion or the
        # maximum autocorrelation time.
        criterion_peak = (samples / self.auto['factor']).max()
        automax_peak = self.auto['automax'].max()
        y_low = self.auto['automin'].min()
        y_high = criterion_peak if criterion_peak > automax_peak else automax_peak
        pl.ylim(y_low, y_high)

        pl.xlabel('Steps per Parameter')
        pl.ylabel('Autocorrelation Time')
        pl.legend()
        figure.tight_layout()

        if self.saveplot is None:
            figure.show()
            return

        figure.savefig(self.saveplot, dpi=150)
        print("Auto plot saved to %s" % self.saveplot)
class CornerPlot(object):
    """
    Class to handle the creation of a corner plot from output MCMC chains
    and a posterior object.

    Args:
        post (radvel.Posterior): radvel posterior object
        chains (DataFrame): MCMC chains output by radvel.mcmc
        saveplot (str, optional): Name of output file, will show as
            interactive matplotlib window if not defined.
    """

    def __init__(self, post: Posterior, chains: DataFrame,
                 saveplot: str | None = None) -> None:
        self.post = post
        self.chains = chains
        self.saveplot = saveplot

        # Only the parameters that are allowed to vary are plotted.
        self.labels = [k for k in post.params.keys() if post.params[k].vary]

        # Build the LaTeX label lookup once rather than once per label;
        # fall back to the raw parameter name when no label is known.
        tex_lookup = post.params.tex_labels()
        self.texlabels = [tex_lookup.get(name, name) for name in self.labels]

    def plot(self) -> None:
        """
        Make and either save or display the corner plot

        The global matplotlib font size is set to 12 while the figure is
        produced and is restored afterwards, including when plotting,
        saving or showing raises.
        """
        saved_font_size = rcParams['font.size']
        rcParams['font.size'] = 12
        try:
            corner.corner(
                self.chains[self.labels],
                labels=self.texlabels,
                label_kwargs={"fontsize": 14},
                plot_datapoints=False,
                bins=30,
                quantiles=[0.16, 0.5, 0.84],
                show_titles=True,
                title_kwargs={"fontsize": 14},
                smooth=True
            )

            if self.saveplot is not None:
                pl.savefig(self.saveplot, dpi=150)
                print("Corner plot saved to %s" % self.saveplot)
            else:
                pl.show()
        finally:
            # rcParams is process-wide state; never leave it modified.
            rcParams['font.size'] = saved_font_size
class DerivedPlot(object):
    """
    Class to handle the creation of a corner plot of derived parameters
    from output MCMC chains and a posterior object.

    Note:
        ``mpsini`` columns whose median exceeds 100 are rescaled in place,
        so the DataFrame passed in as ``chains`` is modified.

    Args:
        chains (DataFrame): MCMC chains output by radvel.mcmc
        P: object representation of config file. ``P.nplanets`` is read,
            and ``P.planet_letters`` is used when it is defined.
        saveplot (Optional[string]): Name of output file, will show as
            interactive matplotlib window if not defined.
    """

    # Scale factors applied to mpsini when the median is above 100.
    # Presumably Earth->Jupiter masses (~1/317.8) and Jupiter->solar
    # masses (~1/1047.9), matching the unit labels used with them.
    _MEARTH_TO_MJUP = 0.00315
    _MJUP_TO_MSUN = 0.000954265748

    def __init__(self, chains: DataFrame, P: ModuleType,
                 saveplot: str | None = None) -> None:
        self.chains = chains
        self.saveplot = saveplot

        if hasattr(P, 'planet_letters'):
            planet_letters = P.planet_letters
        else:
            planet_letters = {
                1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f',
                6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k'
            }

        # Determine which columns to include in corner plot
        self.labels = []
        self.texlabels = []
        self.units = []
        for i in np.arange(1, P.nplanets + 1, 1):
            letter = planet_letters[i]

            for key in 'mpsini rhop a'.split():
                label = '{}{}'.format(key, i)

                # Skip columns that are absent (or duplicated) ...
                if list(self.chains.columns).count(label) != 1:
                    continue
                # ... and columns containing any null value.
                if self.chains[label].isnull().any():
                    continue

                tl = texlabel(label, letter)

                # add units to label
                if key == 'mpsini':
                    unit = "M$_{\\oplus}$"
                    if np.median(self.chains[label]) > 100:
                        unit = "M$_{\\rm Jup}$"
                        self.chains[label] *= self._MEARTH_TO_MJUP
                        # Re-test the converted values: if they are still
                        # above 100, step up to the next unit.
                        if np.median(self.chains[label]) > 100:
                            unit = "M$_{\\odot}$"
                            self.chains[label] *= self._MJUP_TO_MSUN
                elif key == 'rhop':
                    unit = " g cm$^{-3}$"
                elif key == 'a':
                    unit = " AU"
                else:
                    unit = " "

                self.units.append(unit)
                self.labels.append(label)
                self.texlabels.append(tl)

    def plot(self) -> None:
        """
        Make and either save or display the corner plot

        The global matplotlib font size is set to 12 while the figure is
        produced and is restored afterwards, including when plotting,
        saving or showing raises.
        """
        saved_font_size = rcParams['font.size']
        rcParams['font.size'] = 12
        try:
            plot_labels = [
                '{} [{}]'.format(t, u)
                for t, u in zip(self.texlabels, self.units)
            ]

            corner.corner(
                self.chains[self.labels],
                labels=plot_labels,
                label_kwargs={"fontsize": 14},
                plot_datapoints=False,
                bins=30,
                quantiles=[0.16, 0.50, 0.84],
                show_titles=True,
                title_kwargs={"fontsize": 14},
                smooth=True
            )

            if self.saveplot is not None:
                pl.savefig(self.saveplot, dpi=150)
                print("Derived plot saved to %s" % self.saveplot)
            else:
                pl.show()
        finally:
            # rcParams is process-wide state; never leave it modified.
            rcParams['font.size'] = saved_font_size
def texlabel(key: str, letter: str) -> str:
    """
    Build the LaTeX axis label for a derived-parameter column.

    Args:
        key (string): parameter string, i.e. a parameter name optionally
            followed by the planet index (e.g. ``'mpsini1'``, ``'a2'``)
        letter (string): planet letter

    Returns:
        string: LaTeX label for parameter string. Keys that are not one
            of ``mpsini``, ``rhop`` or ``a`` are returned unchanged, so a
            string is always returned.
    """
    # Match on the parameter name itself (trailing planet index removed)
    # so that unrelated keys which merely contain an "a" are not labelled
    # as a semi-major axis.
    base = key.rstrip('0123456789')

    if base == 'mpsini':
        return '$M_' + letter + '\\sin i$'
    if base == 'rhop':
        return '$\\rho_' + letter + '$'
    if base == 'a':
        return "$a_" + letter + "$"

    # Unknown parameter: fall back to the raw key instead of None.
    return key