Source code for pystan.misc

"""PyStan utility functions

These functions validate and organize data passed to and from the
classes and functions defined in the file `stan_fit.hpp` and wrapped
by the Cython file `stan_fit.pxd`.

"""
#-----------------------------------------------------------------------------
# Copyright (c) 2013-2015, PyStan developers
#
# This file is licensed under Version 3.0 of the GNU General Public
# License. See LICENSE for a text of the license.
#-----------------------------------------------------------------------------

# REF: rstan/rstan/R/misc.R

from __future__ import unicode_literals, division
from pystan._compat import PY2, string_types

from collections import OrderedDict
if PY2:
    from collections import Callable, Iterable, Sequence
else:
    from collections.abc import Callable, Iterable, Sequence
import inspect
import io
import itertools
import logging
import math
from numbers import Number
import os
import random
import re
import sys
import shutil
import tempfile
import time

import numpy as np

# READTHEDOCS needs this change
from pystan.external.scipy.mstats import mquantiles

import pystan.chains
import pystan._misc
from pystan.constants import (MAX_UINT, sampling_algo_t, optim_algo_t,
                              variational_algo_t, sampling_metric_t, stan_args_method_t)

logger = logging.getLogger('pystan')


def stansummary(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
    """
    Summary statistic table.

    Parameters
    ----------
    fit : StanFit4Model object
    pars : str or sequence of str, optional
        Parameter names. By default use all parameters
    probs : sequence of float, optional
        Quantiles. By default, (0.025, 0.25, 0.5, 0.75, 0.975)
    digits_summary : int, optional
        Number of significant digits. By default, 2

    Returns
    -------
    summary : string
        Table includes mean, se_mean, sd, probs_0, ..., probs_n, n_eff and Rhat.

    Examples
    --------
    >>> model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
    >>> m = StanModel(model_code=model_code, model_name="example_model")
    >>> fit = m.sampling()
    >>> print(stansummary(fit))
    Inference for Stan model: example_model.
    4 chains, each with iter=2000; warmup=1000; thin=1;
    post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
    y      0.01    0.03    1.0  -2.01  -0.68   0.02   0.72   1.97   1330    1.0
    lp__   -0.5    0.02   0.68  -2.44  -0.66  -0.24  -0.05-5.5e-4   1555    1.0

    Samples were drawn using NUTS at Thu Aug 17 00:52:25 2017.
    For each parameter, n_eff is a crude measure of effective sample size,
    and Rhat is the potential scale reduction factor on split chains (at
    convergence, Rhat=1).
    """
    if fit.mode == 1:
        return "Stan model '{}' is of mode 'test_grad';\n"\
               "sampling is not conducted.".format(fit.model_name)
    elif fit.mode == 2:
        return "Stan model '{}' does not contain samples.".format(fit.model_name)

    n_kept = [s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])]
    header = "Inference for Stan model: {}.\n".format(fit.model_name)
    header += "{} chains, each with iter={}; warmup={}; thin={}; \n"
    header = header.format(fit.sim['chains'], fit.sim['iter'], fit.sim['warmup'],
                           fit.sim['thin'], sum(n_kept))
    header += "post-warmup draws per chain={}, total post-warmup draws={}.\n\n"
    header = header.format(n_kept[0], sum(n_kept))
    footer = "\n\nSamples were drawn using {} at {}.\n"\
        "For each parameter, n_eff is a crude measure of effective sample size,\n"\
        "and Rhat is the potential scale reduction factor on split chains (at \n"\
        "convergence, Rhat=1)."
    sampler = fit.sim['samples'][0]['args']['sampler_t']
    date = fit.date.strftime('%c')  # %c is locale's representation
    footer = footer.format(sampler, date)
    s = _summary(fit, pars, probs)
    body = _array_to_table(s['summary'], s['summary_rownames'],
                           s['summary_colnames'], digits_summary)
    return header + body + footer

def _print_stanfit(fit, pars=None, probs=(0.025, 0.25, 0.5, 0.75, 0.975), digits_summary=2):
    # warning added in PyStan 2.17.0
    logger.warning('Function `_print_stanfit` is deprecated and will be removed in a future version. '
                   'Use `stansummary` instead.')
    return stansummary(fit, pars=pars, probs=probs, digits_summary=digits_summary)

def _array_to_table(arr, rownames, colnames, n_digits):
    """Print an array with row and column names

    Example:
                  mean se_mean  sd 2.5%  25%  50%  75% 97.5% n_eff Rhat
        beta[1,1]  0.0     0.0 1.0 -2.0 -0.7  0.0  0.7   2.0  4000    1
        beta[1,2]  0.0     0.0 1.0 -2.1 -0.7  0.0  0.7   2.0  4000    1
        beta[2,1]  0.0     0.0 1.0 -2.0 -0.7  0.0  0.7   2.0  4000    1
        beta[2,2]  0.0     0.0 1.0 -1.9 -0.6  0.0  0.7   2.0  4000    1
        lp__      -4.2     0.1 2.1 -9.4 -5.4 -3.8 -2.7  -1.2   317    1
    """
    assert arr.shape == (len(rownames), len(colnames))
    rownames_maxwidth = max(len(n) for n in rownames)
    max_col_width = 7
    min_col_width = 5
    max_col_header_num_width = [max(max_col_width, max(len(n) + 1, min_col_width)) for n in colnames]
    rows = []
    for row in arr:
        row_nums = []
        for j, (num, width) in enumerate(zip(row, max_col_header_num_width)):
            if colnames[j] == "n_eff":
                num = int(round(num, 0)) if not np.isnan(num) else num
            num = _format_number(num, n_digits, max_col_width - 1)
            row_nums.append(num)
            if len(num) + 1 > max_col_header_num_width[j]:
                max_col_header_num_width[j] = len(num) + 1
        rows.append(row_nums)
    widths = [rownames_maxwidth] + max_col_header_num_width
    header = '{:>{width}}'.format('', width=widths[0])
    for name, width in zip(colnames, widths[1:]):
        header += '{name:>{width}}'.format(name=name, width=width)
    lines = [header]
    for rowname, row in zip(rownames, rows):
        line = '{name:{width}}'.format(name=rowname, width=widths[0])
        for j, (num, width) in enumerate(zip(row, widths[1:])):
            line += '{num:>{width}}'.format(num=num, width=width)
        lines.append(line)
    return '\n'.join(lines)


def _number_width(n):
    """Calculate the width in characters required to print a number

    For example, -1024 takes 5 characters. -0.034 takes 6 characters.
    """
    return len(str(n))


def _format_number_si(num, n_signif_figures):
    """Format a number using scientific notation to given significant figures"""
    if math.isnan(num) or math.isinf(num):
        return str(num)
    leading, exp = '{:E}'.format(num).split('E')
    leading = round(float(leading), n_signif_figures - 1)
    exp = exp[:1] + exp[2:] if exp[1] == '0' else exp
    formatted = '{}e{}'.format(leading, exp.lstrip('+'))
    return formatted


def _format_number(num, n_signif_figures, max_width):
    """Format a number as a string while obeying space constraints.

    `n_signif_figures` is the minimum number of significant figures expressed
    `max_width` is the maximum width in characters allowed
    """
    if max_width < 6:
        raise NotImplementedError("Guaranteed formatting in fewer than 6 characters not supported.")
    if math.isnan(num) or math.isinf(num):
        return str(num)
    # add 0.5 to prevent log(0) errors; only affects n_digits calculation for num > 0
    n_digits = lambda num: math.floor(math.log10(abs(num) + 0.5)) + 1
    if abs(num) > 10**-n_signif_figures and n_digits(num) <= max_width - n_signif_figures:
        return str(round(num, n_signif_figures))[:max_width].rstrip('.')
    elif _number_width(num) <= max_width:
        if n_digits(num) >= n_signif_figures:
            # the int() is necessary for consistency between Python 2 and 3
            return str(int(round(num)))
        else:
            return str(num)
    else:
        return _format_number_si(num, n_signif_figures)

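# Illustrative behaviour of the formatting helpers above (a sketch; these
# doctest-style lines are not part of the original module):
#
#     >>> _format_number_si(12345.6, 3)
#     '1.23e4'
#     >>> _format_number(0.000123, 2, 6)
#     '1.2e-4'
#     >>> _format_number(3.14159, 2, 6)
#     '3.14'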

def _summary(fit, pars=None, probs=None, **kwargs):
    """Summarize samples (compute mean, SD, quantiles) in all chains.

    REF: stanfit-class.R summary method

    Parameters
    ----------
    fit : StanFit4Model object
    pars : str or sequence of str, optional
        Parameter names. By default use all parameters
    probs : sequence of float, optional
        Quantiles. By default, (0.025, 0.25, 0.5, 0.75, 0.975)

    Returns
    -------
    summaries : OrderedDict of array
        Array indexed by 'summary' has dimensions (num_params, num_statistics).
        Parameters are unraveled in *row-major order*. Statistics include: mean,
        se_mean, sd, probs_0, ..., probs_n, n_eff, and Rhat. Array indexed by
        'c_summary' breaks down the statistics by chain and has dimensions
        (num_params, num_statistics_c_summary, num_chains). Statistics for
        `c_summary` are the same as for `summary` with the exception that
        se_mean, n_eff, and Rhat are absent. Row names and column names are
        also included in the OrderedDict.
    """
    if fit.mode == 1:
        msg = "Stan model {} is of mode 'test_grad'; sampling is not conducted."
        msg = msg.format(fit.model_name)
        raise ValueError(msg)
    elif fit.mode == 2:
        msg = "Stan model {} contains no samples.".format(fit.model_name)
        raise ValueError(msg)

    if fit.sim['n_save'] == fit.sim['warmup2']:
        msg = "Stan model {} contains no samples.".format(fit.model_name)
        raise ValueError(msg)

    # rstan checks for cached summaries here

    if pars is None:
        pars = fit.sim['pars_oi']
    elif isinstance(pars, string_types):
        pars = [pars]
    pars = _remove_empty_pars(pars, fit.sim['pars_oi'], fit.sim['dims_oi'])

    if probs is None:
        probs = (0.025, 0.25, 0.5, 0.75, 0.975)
    ss = _summary_sim(fit.sim, pars, probs)
    s1 = np.column_stack([ss['msd'][:, 0], ss['sem'], ss['msd'][:, 1], ss['quan'], ss['ess'], ss['rhat']])
    s1_rownames = ss['c_msd_names']['parameters']
    s1_colnames = ((ss['c_msd_names']['stats'][0],) + ('se_mean',) +
                   (ss['c_msd_names']['stats'][1],) + ss['c_quan_names']['stats'] +
                   ('n_eff', 'Rhat'))
    s2 = _combine_msd_quan(ss['c_msd'], ss['c_quan'])
    s2_rownames = ss['c_msd_names']['parameters']
    s2_colnames = ss['c_msd_names']['stats'] + ss['c_quan_names']['stats']
    return OrderedDict(summary=s1, c_summary=s2,
                       summary_rownames=s1_rownames,
                       summary_colnames=s1_colnames,
                       c_summary_rownames=s2_rownames,
                       c_summary_colnames=s2_colnames)


def _combine_msd_quan(msd, quan):
    """Combine msd and quantiles in chain summary

    Parameters
    ----------
    msd : array of shape (num_params, 2, num_chains)
       mean and sd for chains
    quan : array of shape (num_params, num_quan, num_chains)
        quantiles for chains

    Returns
    -------
    msdquan : array of shape (num_params, 2 + num_quan, num_chains)
    """
    dim1 = msd.shape
    n_par, _, n_chains = dim1
    ll = []
    for i in range(n_chains):
        a1 = msd[:, :, i]
        a2 = quan[:, :, i]
        ll.append(np.column_stack([a1, a2]))
    msdquan = np.dstack(ll)
    return msdquan


def _summary_sim(sim, pars, probs):
    """Summarize chains together and separately

    REF: rstan/rstan/R/misc.R

    Parameters are unraveled in *column-major order*.

    Parameters
    ----------
    sim : dict
        dict from a stanfit object, i.e., fit['sim']
    pars : Iterable of str
        parameter names
    probs : Iterable of float
        desired quantiles

    Returns
    -------
    summaries : OrderedDict of array
        This dictionary contains the following arrays indexed by the keys
        given below:
        - 'msd' : array of shape (num_params, 2) with mean and sd
        - 'sem' : array of length num_params with standard error for the mean
        - 'c_msd' : array of shape (num_params, 2, num_chains)
        - 'quan' : array of shape (num_params, num_quan)
        - 'c_quan' : array of shape (num_params, num_quan, num_chains)
        - 'ess' : array of length num_params
        - 'rhat' : array of length num_params

    Note
    ----
    `_summary_sim` has the parameters in *column-major* order whereas `_summary`
    gives them in *row-major* order. (This follows RStan.)
    """
    # NOTE: this follows RStan rather closely.
    probs_len = len(probs)
    n_chains = len(sim['samples'])
    # tidx is a dict with keys that are parameters and values that are their
    # indices using column-major ordering
    tidx = _pars_total_indexes(sim['pars_oi'], sim['dims_oi'], sim['fnames_oi'], pars)
    tidx_colm = [tidx[par] for par in pars]
    tidx_colm = list(itertools.chain(*tidx_colm))  # like R's unlist()
    tidx_rowm = [tidx[par+'_rowmajor'] for par in pars]
    tidx_rowm = list(itertools.chain(*tidx_rowm))
    tidx_len = len(tidx_colm)
    lmsdq = [_get_par_summary(sim, i, probs) for i in tidx_colm]
    msd = np.row_stack([x['msd'] for x in lmsdq])
    quan = np.row_stack([x['quan'] for x in lmsdq])
    probs_str = tuple(["{:g}%".format(100*p) for p in probs])
    msd = msd.reshape(tidx_len, 2, order='F')
    quan = quan.reshape(tidx_len, probs_len, order='F')

    c_msd = np.row_stack([x['c_msd'] for x in lmsdq])
    c_quan = np.row_stack([x['c_quan'] for x in lmsdq])
    c_msd = c_msd.reshape(tidx_len, 2, n_chains, order='F')
    c_quan = c_quan.reshape(tidx_len, probs_len, n_chains, order='F')
    sim_attr_args = sim.get('args', None)
    if sim_attr_args is None:
        cids = list(range(n_chains))
    else:
        cids = [x['chain_id'] for x in sim_attr_args]

    c_msd_names = dict(parameters=np.asarray(sim['fnames_oi'])[tidx_colm],
                       stats=("mean", "sd"),
                       chains=tuple("chain:{}".format(cid) for cid in cids))
    c_quan_names = dict(parameters=np.asarray(sim['fnames_oi'])[tidx_colm],
                        stats=probs_str,
                        chains=tuple("chain:{}".format(cid) for cid in cids))
    ess_and_rhat = np.array([pystan.chains.ess_and_splitrhat(sim, n) for n in tidx_colm])
    ess, rhat = [arr.ravel() for arr in np.hsplit(ess_and_rhat, 2)]
    return dict(msd=msd, c_msd=c_msd, c_msd_names=c_msd_names, quan=quan,
                c_quan=c_quan, c_quan_names=c_quan_names,
                sem=msd[:, 1] / np.sqrt(ess), ess=ess, rhat=rhat,
                row_major_idx=tidx_rowm, col_major_idx=tidx_colm)


def _get_par_summary(sim, n, probs):
    """Summarize chains merged and individually

    Parameters
    ----------
    sim : dict from stanfit object
    n : int
        parameter index
    probs : iterable of float
        quantiles

    Returns
    -------
    summary : dict
       Dictionary containing summaries
    """
    # _get_samples gets chains for nth parameter
    ss = _get_samples(n, sim, inc_warmup=False)
    msdfun = lambda chain: (np.mean(chain), np.std(chain, ddof=1))
    qfun = lambda chain: mquantiles(chain, probs)
    c_msd = np.array([msdfun(s) for s in ss]).flatten()
    c_quan = np.array([qfun(s) for s in ss]).flatten()
    ass = np.asarray(ss).flatten()
    msd = np.asarray(msdfun(ass))
    quan = qfun(np.asarray(ass))
    return dict(msd=msd, quan=quan, c_msd=c_msd, c_quan=c_quan)


def _split_data(data):
    data_r = {}
    data_i = {}
    # data_r and data_i are going to be converted into C++ objects of
    # type: map<string, pair<vector<double>, vector<size_t>>> and
    # map<string, pair<vector<int>, vector<size_t>>> so prepare
    # them accordingly.
    for k, v in data.items():
        if np.issubdtype(np.asarray(v).dtype, np.integer):
            data_i.update({k.encode('utf-8'): np.asarray(v, dtype=int)})
        elif np.issubdtype(np.asarray(v).dtype, np.floating):
            data_r.update({k.encode('utf-8'): np.asarray(v, dtype=float)})
        else:
            msg = "Variable {} is neither int nor float nor list/array thereof"
            raise ValueError(msg.format(k))
    return data_r, data_i

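# For example (an illustrative sketch, not part of the original module),
# integer and floating-point data are separated and keys are encoded as bytes
# for the C++ layer:
#
#     >>> data_r, data_i = _split_data({'N': 2, 'y': [0.1, 0.2]})
#     >>> sorted(data_i)
#     [b'N']
#     >>> sorted(data_r)
#     [b'y']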

def _config_argss(chains, iter, warmup, thin,
                  init, seed, sample_file, diagnostic_file, algorithm,
                  control, **kwargs):
    # After rstan/rstan/R/misc.R (config_argss)
    iter = int(iter)
    if iter < 1:
        raise ValueError("`iter` should be a positive integer.")
    thin = int(thin)
    if thin < 1 or thin > iter:
        raise ValueError("`thin should be a positive integer "
                         "less than `iter`.")
    warmup = max(0, int(warmup))
    if warmup > iter:
        raise ValueError("`warmup` should be an integer less than `iter`.")
    chains = int(chains)
    if chains < 1:
        raise ValueError("`chains` should be a positive integer.")

    iters = [iter] * chains
    thins = [thin] * chains
    warmups = [warmup] * chains

    # use chain_id argument if specified
    if kwargs.get('chain_id') is None:
        chain_id = list(range(chains))
    else:
        chain_id = [int(id) for id in kwargs['chain_id']]
        if len(set(chain_id)) != len(chain_id):
            raise ValueError("`chain_id` has duplicated elements.")
        chain_id_len = len(chain_id)
        if chain_id_len < chains:
            chain_id = chain_id + [max(chain_id) + 1 + i
                                   for i in range(chains - chain_id_len)]
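            # (illustrative) e.g. chains=4 with chain_id=[10, 11] pads the
            # list to [10, 11, 12, 13]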
        del kwargs['chain_id']

    inits_specified = False
    # slight difference here from rstan; Python's lists are not typed.
    if isinstance(init, Number):
        init = str(init)
    if isinstance(init, string_types):
        if init in ['0', 'random']:
            inits = [init] * chains
        else:
            inits = ["random"] * chains
        inits_specified = True
    if not inits_specified and isinstance(init, Callable):
        ## test if function takes argument named "chain_id"
        if "chain_id" in inspect.getargspec(init).args:
            inits = [init(chain_id=id) for id in chain_id]
        else:
            inits = [init()] * chains
        if not isinstance(inits[0], dict):
            raise ValueError("The function specifying initial values must "
                             "return a dictionary.")
        inits_specified = True
    if not inits_specified and isinstance(init, Sequence):
        if len(init) != chains:
            raise ValueError("Length of list of initial values does not "
                             "match number of chains.")
        if not all([isinstance(d, dict) for d in init]):
            raise ValueError("Initial value list is not a sequence of "
                             "dictionaries.")
        inits = init
        inits_specified = True
    if not inits_specified:
        raise ValueError("Invalid specification of initial values.")

    ## only one seed is needed by virtue of the RNG
    seed = _check_seed(seed)

    kwargs['method'] = "test_grad" if kwargs.get('test_grad') else 'sampling'

    all_control = {
        "adapt_engaged", "adapt_gamma", "adapt_delta", "adapt_kappa",
        "adapt_t0", "adapt_init_buffer", "adapt_term_buffer", "adapt_window",
        "stepsize", "stepsize_jitter", "metric", "int_time",
        "max_treedepth", "epsilon", "error", "inv_metric"
    }
    all_metrics = {"unit_e", "diag_e", "dense_e"}

    if control is not None:
        if not isinstance(control, dict):
            raise ValueError("`control` must be a dictionary")
        if not all(key in all_control for key in control):
            unknown = set(control) - all_control
            raise ValueError("`control` contains unknown parameters: {}".format(unknown))
        if control.get('metric') and control['metric'] not in all_metrics:
            raise ValueError("`metric` must be one of {}".format(all_metrics))
        kwargs['control'] = control

    argss = [dict() for _ in range(chains)]
    for i in range(chains):
        argss[i] = dict(chain_id=chain_id[i],
                        iter=iters[i], thin=thins[i], seed=seed,
                        warmup=warmups[i], init=inits[i],
                        algorithm=algorithm)

    if sample_file is not None:
        sample_file = _writable_sample_file(sample_file)
        if chains == 1:
            argss[0]['sample_file'] = sample_file
        elif chains > 1:
            for i in range(chains):
                argss[i]['sample_file'] = _append_id(sample_file, i)

    if diagnostic_file is not None:
        raise NotImplementedError("diagnostic_file not implemented yet.")

    if control is not None and "inv_metric" in control:
        inv_metric = control.pop("inv_metric")
        metric_dir = tempfile.mkdtemp()
        if isinstance(inv_metric, dict):
            for i in range(chains):
                if i not in inv_metric:
                    msg = "Invalid value for init_inv_metric found (keys={}). " \
                          "Use either a dictionary with chain_index as keys (0,1,2,...)" \
                          "or ndarray."
                    msg = msg.format(list(metric_file.keys()))
                    raise ValueError(msg)
                mass_values = inv_metric[i]
                metric_filename = "inv_metric_chain_{}.Rdata".format(str(i))
                metric_path = os.path.join(metric_dir, metric_filename)
                if isinstance(mass_values, str):
                    if not os.path.exists(mass_values):
                        raise ValueError("inverse metric file was not found: {}".format(mass_values))
                    shutil.copy(mass_values, metric_path)
                else:
                    stan_rdump(dict(inv_metric=mass_values), metric_path)
                argss[i]['metric_file'] = metric_path
        elif isinstance(inv_metric, str):
            if not os.path.exists(inv_metric):
                raise ValueError("inverse metric  file was not found: {}".format(inv_metric))
            for i in range(chains):
                metric_filename = "inv_metric_chain_{}.Rdata".format(str(i))
                metric_path = os.path.join(metric_dir, metric_filename)
                shutil.copy(inv_metric, metric_path)
                argss[i]['metric_file'] = metric_path
        elif isinstance(inv_metric, Iterable):
            metric_filename = "inv_metric_chain_0.Rdata"
            metric_path = os.path.join(metric_dir, metric_filename)
            stan_rdump(dict(inv_metric=inv_metric), metric_path)
            argss[0]['metric_file'] = metric_path
            for i in range(1, chains):
                metric_filename = "inv_metric_chain_{}.Rdata".format(str(i))
                metric_path = os.path.join(metric_dir, metric_filename)
                shutil.copy(argss[i-1]['metric_file'], metric_path)
                argss[i]['metric_file'] = metric_path
        else:
            for i in range(chains):
                argss[i]['metric_file'] = ""

    stepsize_list = None
    if "control" in kwargs and "stepsize" in kwargs["control"]:
        if isinstance(kwargs["control"]["stepsize"], Sequence):
            if len(kwargs["control"]["stepsize"]) == 1:
                # a single value applies to every chain; leave stepsize_list
                # unset so the scalar propagates via kwargs
                kwargs["control"]["stepsize"] = kwargs["control"]["stepsize"][0]
            elif len(kwargs["control"]["stepsize"]) != chains:
                raise ValueError("stepsize length needs to equal chain count.")
            else:
                stepsize_list = kwargs["control"]["stepsize"]

    for i in range(chains):
        argss[i].update(kwargs)
        if stepsize_list is not None:
            argss[i]["control"]["stepsize"] = stepsize_list[i]
        argss[i] = _get_valid_stan_args(argss[i])

    return argss


def _get_valid_stan_args(base_args=None):
    """Fill in default values for arguments not provided in `base_args`.

    RStan does this in C++ in stan_args.hpp in the stan_args constructor.
    It seems easier to deal with here in Python.

    """
    args = base_args.copy() if base_args is not None else {}
    # Default arguments, c.f. rstan/rstan/inst/include/rstan/stan_args.hpp
    # values in args are going to be converted into C++ objects so
    # prepare them accordingly---e.g., unicode -> bytes -> std::string
    args['chain_id'] = args.get('chain_id', 1)
    args['append_samples'] = args.get('append_samples', False)
    if args.get('method') is None or args['method'] == "sampling":
        args['method'] = stan_args_method_t.SAMPLING
    elif args['method'] == "optim":
        args['method'] = stan_args_method_t.OPTIM
    elif args['method'] == 'test_grad':
        args['method'] = stan_args_method_t.TEST_GRADIENT
    elif args['method'] == 'variational':
        args['method'] = stan_args_method_t.VARIATIONAL
    else:
        args['method'] = stan_args_method_t.SAMPLING
    args['sample_file_flag'] = True if args.get('sample_file') else False
    args['sample_file'] = args.get('sample_file', '').encode('ascii')
    args['diagnostic_file_flag'] = True if args.get('diagnostic_file') else False
    args['diagnostic_file'] = args.get('diagnostic_file', '').encode('ascii')
    # NB: argument named "seed" not "random_seed"
    args['random_seed'] = args.get('seed', int(time.time()))

    args['metric_file_flag'] = True if args.get('metric_file') else False
    args['metric_file'] = args.get('metric_file', '').encode('ascii')

    if args['method'] == stan_args_method_t.VARIATIONAL:
        # variational does not use a `control` map like sampling
        args['ctrl'] = args.get('ctrl', dict(variational=dict()))
        args['ctrl']['variational']['iter'] = args.get('iter', 10000)
        args['ctrl']['variational']['grad_samples'] = args.get('grad_samples', 1)
        args['ctrl']['variational']['elbo_samples'] = args.get('elbo_samples', 100)
        args['ctrl']['variational']['eval_elbo'] = args.get('eval_elbo', 100)
        args['ctrl']['variational']['output_samples'] = args.get('output_samples', 1000)
        args['ctrl']['variational']['adapt_iter'] = args.get('adapt_iter', 50)
        args['ctrl']['variational']['eta'] = args.get('eta', 1.0)
        args['ctrl']['variational']['adapt_engaged'] = args.get('adapt_engaged', True)
        args['ctrl']['variational']['tol_rel_obj'] = args.get('tol_rel_obj', 0.01)
        if args.get('algorithm', '').lower() == 'fullrank':
            args['ctrl']['variational']['algorithm'] = variational_algo_t.FULLRANK
        else:
            args['ctrl']['variational']['algorithm'] = variational_algo_t.MEANFIELD
    elif args['method'] == stan_args_method_t.SAMPLING:
        args['ctrl'] = args.get('ctrl', dict(sampling=dict()))
        args['ctrl']['sampling']['iter'] = iter = args.get('iter', 2000)
        args['ctrl']['sampling']['warmup'] = warmup = args.get('warmup', iter // 2)
        calculated_thin = (iter - warmup) // 1000
        if calculated_thin < 1:
            calculated_thin = 1
        args['ctrl']['sampling']['thin'] = thin = args.get('thin', calculated_thin)
        args['ctrl']['sampling']['save_warmup'] = True  # always True now
        args['ctrl']['sampling']['iter_save_wo_warmup'] = iter_save_wo_warmup = 1 + (iter - warmup - 1) // thin
        args['ctrl']['sampling']['iter_save'] = iter_save_wo_warmup + 1 + (warmup - 1) // thin
        refresh = iter // 10 if iter >= 20 else 1
        args['ctrl']['sampling']['refresh'] = args.get('refresh', refresh)

        ctrl_lst = args.get('control', dict())
        ctrl_sampling = args['ctrl']['sampling']
        # NB: if these defaults change, remember to update docstrings
        ctrl_sampling['adapt_engaged'] = ctrl_lst.get("adapt_engaged", True)
        ctrl_sampling['adapt_gamma'] = ctrl_lst.get("adapt_gamma", 0.05)
        ctrl_sampling['adapt_delta'] = ctrl_lst.get("adapt_delta", 0.8)
        ctrl_sampling['adapt_kappa'] = ctrl_lst.get("adapt_kappa", 0.75)
        ctrl_sampling['adapt_t0'] = ctrl_lst.get("adapt_t0", 10.0)
        ctrl_sampling['adapt_init_buffer'] = ctrl_lst.get("adapt_init_buffer", 75)
        ctrl_sampling['adapt_term_buffer'] = ctrl_lst.get("adapt_term_buffer", 50)
        ctrl_sampling['adapt_window'] = ctrl_lst.get("adapt_window", 25)
        ctrl_sampling['stepsize'] = ctrl_lst.get("stepsize", 1.0)
        ctrl_sampling['stepsize_jitter'] = ctrl_lst.get("stepsize_jitter", 0.0)

        algorithm = args.get('algorithm', 'NUTS')
        if algorithm == 'HMC':
            args['ctrl']['sampling']['algorithm'] = sampling_algo_t.HMC
        elif algorithm == 'Metropolis':
            args['ctrl']['sampling']['algorithm'] = sampling_algo_t.Metropolis
        elif algorithm == 'NUTS':
            args['ctrl']['sampling']['algorithm'] = sampling_algo_t.NUTS
        elif algorithm == 'Fixed_param':
            args['ctrl']['sampling']['algorithm'] = sampling_algo_t.Fixed_param
            # TODO: Setting adapt_engaged to False solves the segfault reported
            # in issue #200; find out why this hack is needed. RStan deals with
            # the setting elsewhere.
            ctrl_sampling['adapt_engaged'] = False
        else:
            msg = "Invalid value for parameter algorithm (found {}; " \
                "require HMC, Metropolis, NUTS, or Fixed_param).".format(algorithm)
            raise ValueError(msg)

        metric = ctrl_lst.get('metric', 'diag_e')
        if metric == "unit_e":
            ctrl_sampling['metric'] = sampling_metric_t.UNIT_E
        elif metric == "diag_e":
            ctrl_sampling['metric'] = sampling_metric_t.DIAG_E
        elif metric == "dense_e":
            ctrl_sampling['metric'] = sampling_metric_t.DENSE_E

        if ctrl_sampling['algorithm'] == sampling_algo_t.NUTS:
            ctrl_sampling['max_treedepth'] = ctrl_lst.get("max_treedepth", 10)
        elif ctrl_sampling['algorithm'] == sampling_algo_t.HMC:
            ctrl_sampling['int_time'] = ctrl_lst.get('int_time', 6.283185307179586476925286766559005768e+00)
        elif ctrl_sampling['algorithm'] == sampling_algo_t.Metropolis:
            pass
        elif ctrl_sampling['algorithm'] == sampling_algo_t.Fixed_param:
            pass

    elif args['method'] == stan_args_method_t.OPTIM:
        args['ctrl'] = args.get('ctrl', dict(optim=dict()))
        args['ctrl']['optim']['iter'] = iter = args.get('iter', 2000)
        algorithm = args.get('algorithm', 'LBFGS')
        if algorithm == "BFGS":
            args['ctrl']['optim']['algorithm'] = optim_algo_t.BFGS
        elif algorithm == "Newton":
            args['ctrl']['optim']['algorithm'] = optim_algo_t.Newton
        elif algorithm == "LBFGS":
            args['ctrl']['optim']['algorithm'] = optim_algo_t.LBFGS
        else:
            msg = "Invalid value for parameter algorithm (found {}; " \
                  "require (L)BFGS or Newton).".format(algorithm)
            raise ValueError(msg)
        refresh = args['ctrl']['optim']['iter'] // 100
        args['ctrl']['optim']['refresh'] = args.get('refresh', refresh)
        if args['ctrl']['optim']['refresh'] < 1:
            args['ctrl']['optim']['refresh'] = 1
        args['ctrl']['optim']['init_alpha'] = args.get("init_alpha", 0.001)
        args['ctrl']['optim']['tol_obj'] = args.get("tol_obj", 1e-12)
        args['ctrl']['optim']['tol_grad'] = args.get("tol_grad", 1e-8)
        args['ctrl']['optim']['tol_param'] = args.get("tol_param", 1e-8)
        args['ctrl']['optim']['tol_rel_obj'] = args.get("tol_rel_obj", 1e4)
        args['ctrl']['optim']['tol_rel_grad'] = args.get("tol_rel_grad", 1e7)
        args['ctrl']['optim']['save_iterations'] = args.get("save_iterations", True)
        args['ctrl']['optim']['history_size'] = args.get("history_size", 5)
    elif args['method'] == stan_args_method_t.TEST_GRADIENT:
        args['ctrl'] = args.get('ctrl', dict(test_grad=dict()))
        args['ctrl']['test_grad']['epsilon'] = args.get("epsilon", 1e-6)
        args['ctrl']['test_grad']['error'] = args.get("error", 1e-6)

    init = args.get('init', "random")
    if isinstance(init, string_types):
        args['init'] = init.encode('ascii')
    elif isinstance(init, dict):
        args['init'] = "user".encode('ascii')
        # while the name is 'init_list', it is a dict; the name comes from rstan,
        # where list elements can have names
        args['init_list'] = init
    else:
        args['init'] = "random".encode('ascii')

    args['init_radius'] = args.get('init_r', 2.0)
    if (args['init_radius'] <= 0):
        args['init'] = b"0"

    # 0 initialization requires init_radius = 0
    if (args['init'] == b"0" or args['init'] == 0):
        args['init_radius'] = 0.0

    args['enable_random_init'] = args.get('enable_random_init', True)
    # RStan calls validate_args() here
    return args


def _check_seed(seed):
    """If possible, convert `seed` into a valid form for Stan (an integer
    between 0 and MAX_UINT, inclusive). If not possible, use a random seed
    instead and raise a warning if `seed` was not provided as `None`.
    """
    if isinstance(seed, (Number, string_types)):
        try:
            seed = int(seed)
        except ValueError:
            logger.warning("`seed` must be castable to an integer")
            seed = None
        else:
            if seed < 0:
                logger.warning("`seed` may not be negative")
                seed = None
            elif seed > MAX_UINT:
                raise ValueError('`seed` is too large; max is {}'.format(MAX_UINT))
    elif isinstance(seed, np.random.RandomState):
        seed = seed.randint(0, MAX_UINT)
    elif seed is not None:
        logger.warning('`seed` has unexpected type')
        seed = None

    if seed is None:
        seed = random.randint(0, MAX_UINT)

    return seed

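# Illustrative behaviour (a sketch, not part of the original module):
#
#     >>> _check_seed(42)
#     42
#     >>> _check_seed("42")
#     42
#     >>> isinstance(_check_seed(None), int)  # falls back to a random seed
#     True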

def _organize_inits(inits, pars, dims):
    """Obtain a list of initial values for each chain.

    The parameter 'lp__' will be removed from `pars` and `dims`.

    Parameters
    ----------
    inits : list
        list of initial values for each chain.
    pars : list of str
    dims : list of list of int
        from (via cython conversion) vector[vector[uint]] dims

    Returns
    -------
    inits : list of dict

    """
    try:
        idx_of_lp = pars.index('lp__')
        del pars[idx_of_lp]
        del dims[idx_of_lp]
    except ValueError:
        pass
    starts = _calc_starts(dims)
    return [_par_vector2dict(init, pars, dims, starts) for init in inits]


def _calc_starts(dims):
    """Calculate starting indexes

    Parameters
    ----------
    dims : list of list of int
        from (via cython conversion) vector[vector[uint]] dims

    Examples
    --------
    >>> _calc_starts([[8, 2], [5], [6, 2]])
    [0, 16, 21]

    """
    # NB: Python uses 0-indexing; R uses 1-indexing.
    l = len(dims)
    s = [np.prod(d) for d in dims]
    starts = np.cumsum([0] + s)[0:l].tolist()
    # coerce things into ints before returning
    return [int(i) for i in starts]


def _par_vector2dict(v, pars, dims, starts=None):
    """Turn a vector of samples into an OrderedDict according to param dims.

    Parameters
    ----------
    v : list of int or float
    pars : list of str
        parameter names
    dims : list of list of int
        list of dimensions of parameters

    Returns
    -------
    d : dict

    Examples
    --------
    >>> v = list(range(31))
    >>> dims = [[5], [5, 5], []]
    >>> pars = ['mu', 'Phi', 'eta']
    >>> _par_vector2dict(v, pars, dims)  # doctest: +ELLIPSIS
    OrderedDict([('mu', array([0, 1, 2, 3, 4])), ('Phi', array([[ 5, ...

    """
    if starts is None:
        starts = _calc_starts(dims)
    d = OrderedDict()
    for i in range(len(pars)):
        l = int(np.prod(dims[i]))
        start = starts[i]
        end = start + l
        y = np.asarray(v[start:end])
        if len(dims[i]) > 1:
            y = y.reshape(dims[i], order='F')  # 'F' = Fortran, column-major
        d[pars[i]] = y.squeeze() if y.shape == (1,) else y
    return d


def _check_pars(allpars, pars):
    if len(pars) == 0:
        raise ValueError("No parameter specified (`pars` is empty).")
    for par in pars:
        if par not in allpars:
            raise ValueError("No parameter {}".format(par))


def _pars_total_indexes(names, dims, fnames, pars):
    """Obtain all the indexes for parameters `pars` in the sequence of names.

    `names` references variables that are in column-major order

    Parameters
    ----------
    names : sequence of str
        All the parameter names.
    dims : sequence of list of int
        Dimensions, in same order as `names`.
    fnames : sequence of str
        All the scalar parameter names
    pars : sequence of str
        The parameters of interest. It is assumed all elements in `pars` are in
        `names`.

    Returns
    -------
    indexes : OrderedDict of list of int
        Dictionary uses parameter names as keys. Indexes are column-major order.
        For each parameter there is also a key `par`+'_rowmajor' that stores the
        row-major indexing.

    Note
    ----
    Inside each parameter (vector or array), the sequence uses column-major
    ordering. For example, if we have parameters alpha and beta, having
    dimensions [2, 2] and [2, 3] respectively, the whole parameter sequence
    is alpha[0,0], alpha[1,0], alpha[0, 1], alpha[1, 1], beta[0, 0],
    beta[1, 0], beta[0, 1], beta[1, 1], beta[0, 2], beta[1, 2]. In short,
    like R's matrix(..., byrow = FALSE).

    Example
    -------
    >>> pars_oi = ['mu', 'tau', 'eta', 'theta', 'lp__']
    >>> dims_oi = [[], [], [8], [8], []]
    >>> fnames_oi = ['mu', 'tau', 'eta[1]', 'eta[2]', 'eta[3]', 'eta[4]',
    ... 'eta[5]', 'eta[6]', 'eta[7]', 'eta[8]', 'theta[1]', 'theta[2]',
    ... 'theta[3]', 'theta[4]', 'theta[5]', 'theta[6]', 'theta[7]',
    ... 'theta[8]', 'lp__']
    >>> pars = ['mu', 'tau', 'eta', 'theta', 'lp__']
    >>> _pars_total_indexes(pars_oi, dims_oi, fnames_oi, pars)
    ... # doctest: +ELLIPSIS
    OrderedDict([('mu', (0,)), ('tau', (1,)), ('eta', (2, 3, ...

    """
    starts = _calc_starts(dims)

    def par_total_indexes(par):
        # if `par` is a scalar, it will match one of `fnames`
        if par in fnames:
            p = fnames.index(par)
            idx = tuple([p])
            return OrderedDict([(par, idx), (par+'_rowmajor', idx)])
        else:
            p = names.index(par)
            idx = starts[p] + np.arange(np.prod(dims[p]))
            idx_rowmajor = starts[p] + _idx_col2rowm(dims[p])
        return OrderedDict([(par, tuple(idx)), (par+'_rowmajor', tuple(idx_rowmajor))])

    indexes = OrderedDict()
    for par in pars:
        indexes.update(par_total_indexes(par))
    return indexes


def _idx_col2rowm(d):
    """Generate indexes to change from col-major to row-major ordering"""
    if 0 == len(d):
        return 1
    if 1 == len(d):
        return np.arange(d[0])
    # order='F' indicates column-major ordering
    idx = np.array(np.arange(np.prod(d))).reshape(d, order='F').T
    return idx.flatten(order='F')

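# Illustrative example (not part of the original module): for a 2 x 3
# parameter, the column-major flattening is reordered into row-major order
# by these indexes:
#
#     >>> _idx_col2rowm([2, 3])
#     array([0, 2, 4, 1, 3, 5])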

def _get_kept_samples(n, sim):
    """Get samples to be kept from the chain(s) for `n`th parameter.

    Samples from different chains are merged.

    Parameters
    ----------
    n : int
    sim : dict
        A dictionary tied to a StanFit4Model instance.

    Returns
    -------
    samples : array
        Samples being kept, permuted and in column-major order.

    """
    return pystan._misc.get_kept_samples(n, sim)


def _get_samples(n, sim, inc_warmup=True):
    # NOTE: this is in stanfit-class.R in RStan (rather than misc.R)
    """Get chains for `n`th parameter.

    Parameters
    ----------
    n : int
    sim : dict
        A dictionary tied to a StanFit4Model instance.

    Returns
    -------
    chains : list of array
        Each chain is an element in the list.

    """
    return pystan._misc.get_samples(n, sim, inc_warmup)


def _redirect_stderr():
    """Redirect stderr for subprocesses to /dev/null

    Silences copious compilation messages.

    Returns
    -------
    orig_stderr : file descriptor
        Copy of original stderr file descriptor
    """
    sys.stderr.flush()
    stderr_fileno = sys.stderr.fileno()
    orig_stderr = os.dup(stderr_fileno)
    devnull = os.open(os.devnull, os.O_WRONLY)
    os.dup2(devnull, stderr_fileno)
    os.close(devnull)
    return orig_stderr

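# Hedged usage sketch (not part of the original module): callers are expected
# to restore the original stderr once the noisy subprocess has finished.
#
#     orig_stderr = _redirect_stderr()
#     try:
#         pass  # e.g. run the compilation step here
#     finally:
#         sys.stderr.flush()
#         os.dup2(orig_stderr, sys.stderr.fileno())  # restore stderr
#         os.close(orig_stderr)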

def _has_fileno(stream):
    """Returns whether the stream object seems to have a working fileno()

    Tells whether _redirect_stderr is likely to work.

    Parameters
    ----------
    stream : IO stream object

    Returns
    -------
    has_fileno : bool
        True if stream.fileno() exists and doesn't raise OSError or
        UnsupportedOperation
    """
    try:
        stream.fileno()
    except (AttributeError, OSError, IOError, io.UnsupportedOperation):
        return False
    return True

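# For example (illustrative, not part of the original module; assumes a
# regular interpreter session where stderr is a real file):
#
#     >>> _has_fileno(sys.__stderr__)
#     True
#     >>> _has_fileno(io.StringIO())
#     False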

def _append_id(file, id, suffix='.csv'):
    fname = os.path.basename(file)
    fpath = os.path.dirname(file)
    fname2 = re.sub(r'\.csv\s*$', '_{}.csv'.format(id), fname)
    if fname2 == fname:
        fname2 = '{}_{}.csv'.format(fname, id)
    return os.path.join(fpath, fname2)

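# For example (illustrative, not part of the original module):
#
#     >>> _append_id('samples.csv', 0)
#     'samples_0.csv'
#     >>> _append_id('samples', 0)
#     'samples_0.csv'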

def _writable_sample_file(file, warn=True, wfun=None):
    """Check to see if file is writable, if not use temporary file"""
    if wfun is None:
        wfun = lambda x, y: '"{}" is not writable; use "{}" instead'.format(x, y)
    dir = os.path.dirname(file)
    dir = os.getcwd() if dir == '' else dir
    if os.access(dir, os.W_OK):
        return file
    else:
        dir2 = tempfile.mkdtemp()
        if warn:
            logger.warning(wfun(dir, dir2))
        return os.path.join(dir2, os.path.basename(file))


def is_legal_stan_vname(name):
    stan_kw1 = ('for', 'in', 'while', 'repeat', 'until', 'if', 'then', 'else',
                'true', 'false')
    stan_kw2 = ('int', 'real', 'vector', 'simplex', 'ordered', 'positive_ordered',
                'row_vector', 'matrix', 'corr_matrix', 'cov_matrix', 'lower', 'upper')
    stan_kw3 = ('model', 'data', 'parameters', 'quantities', 'transformed', 'generated')
    cpp_kw = ("alignas", "alignof", "and", "and_eq", "asm", "auto", "bitand", "bitor", "bool",
              "break", "case", "catch", "char", "char16_t", "char32_t", "class", "compl",
              "const", "constexpr", "const_cast", "continue", "decltype", "default", "delete",
              "do", "double", "dynamic_cast", "else", "enum", "explicit", "export", "extern",
              "false", "float", "for", "friend", "goto", "if", "inline", "int", "long", "mutable",
              "namespace", "new", "noexcept", "not", "not_eq", "nullptr", "operator", "or", "or_eq",
              "private", "protected", "public", "register", "reinterpret_cast", "return",
              "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
              "switch", "template", "this", "thread_local", "throw", "true", "try", "typedef",
              "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile",
              "wchar_t", "while", "xor", "xor_eq")
    illegal = stan_kw1 + stan_kw2 + stan_kw3 + cpp_kw
    if re.findall(r'(\.|^[0-9]|__$)', name):
        return False
    return name not in illegal

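# For example (illustrative, not part of the original module):
#
#     >>> is_legal_stan_vname('mu')
#     True
#     >>> is_legal_stan_vname('for')      # Stan keyword
#     False
#     >>> is_legal_stan_vname('theta__')  # trailing double underscore
#     False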

def _dict_to_rdump(data):
    parts = []
    for name, value in data.items():
        if isinstance(value, (Sequence, Number, np.number, np.ndarray, int, bool, float)) \
           and not isinstance(value, string_types):
            value = np.asarray(value)
        else:
            raise ValueError("Variable {} is not a number and cannot be dumped.".format(name))

        if value.dtype == np.bool:
            value = value.astype(int)

        if value.ndim == 0:
            s = '{} <- {}\n'.format(name, str(value))
        elif value.ndim == 1:
            s = '{} <-\nc({})\n'.format(name, ', '.join(str(v) for v in value))
        elif value.ndim > 1:
            tmpl = '{} <-\nstructure(c({}), .Dim = c({}))\n'
            # transpose value as R uses column-major
            # 'F' = Fortran, column-major
            s = tmpl.format(name,
                            ', '.join(str(v) for v in value.flatten(order='F')),
                            ', '.join(str(v) for v in value.shape))
        parts.append(s)
    return ''.join(parts)

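# For example (illustrative, not part of the original module), a 2 x 2 array
# is flattened in column-major order for R:
#
#     >>> print(_dict_to_rdump({'N': 2, 'y': np.array([[1, 2], [3, 4]])}))
#     N <- 2
#     y <-
#     structure(c(1, 3, 2, 4), .Dim = c(2, 2))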

def stan_rdump(data, filename):
    """
    Dump a dictionary with model data into a file using the R dump format that
    Stan supports.

    Parameters
    ----------
    data : dict
    filename : str

    """
    for name in data:
        if not is_legal_stan_vname(name):
            raise ValueError("Variable name {} is not allowed in Stan".format(name))
    with open(filename, 'w') as f:
        f.write(_dict_to_rdump(data))

def _rdump_value_to_numpy(s):
    """
    Convert an R dump formatted value to its NumPy equivalent.

    For example, "c(1, 2)" becomes ``array([1, 2])``

    Only supports a few R data structures. Will not work with European
    decimal format.
    """
    if "structure" in s:
        vector_str, shape_str = re.findall(r'c\([^\)]+\)', s)
        shape = [int(d) for d in shape_str[2:-1].split(',')]
        if '.' in vector_str:
            arr = np.array([float(v) for v in vector_str[2:-1].split(',')])
        else:
            arr = np.array([int(v) for v in vector_str[2:-1].split(',')])
        # 'F' = Fortran, column-major
        arr = arr.reshape(shape, order='F')
    elif "c(" in s:
        if '.' in s:
            arr = np.array([float(v) for v in s[2:-1].split(',')], order='F')
        else:
            arr = np.array([int(v) for v in s[2:-1].split(',')], order='F')
    else:
        arr = np.array(float(s) if '.' in s else int(s))
    return arr


def _remove_empty_pars(pars, pars_oi, dims_oi):
    """
    Remove parameters that are actually empty. For example, the parameter
    y would be removed with the following model code:

        transformed data { int n; n <- 0; }
        parameters { real y[n]; }

    Parameters
    ----------
    pars: iterable of str
    pars_oi: list of str
    dims_oi: list of list of int

    Returns
    -------
    pars_trimmed: list of str
    """
    pars = list(pars)
    for par, dim in zip(pars_oi, dims_oi):
        if par in pars and np.prod(dim) == 0:
            del pars[pars.index(par)]
    return pars

def read_rdump(filename):
    """
    Read data formatted using the R dump format

    Parameters
    ----------
    filename: str

    Returns
    -------
    data : OrderedDict
    """
    contents = open(filename).read().strip()
    names = [name.strip() for name in re.findall(r'^(\w+) <-', contents, re.MULTILINE)]
    values = [value.strip() for value in re.split(r'\w+ +<-', contents) if value]
    if len(values) != len(names):
        raise ValueError("Unable to read file. Unable to pair variable name with value.")
    d = OrderedDict()
    for name, value in zip(names, values):
        d[name.strip()] = _rdump_value_to_numpy(value.strip())
    return d
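
# Illustrative round trip (a sketch, not part of the original module; the
# file path is hypothetical):
#
#     >>> stan_rdump({'N': 2, 'y': np.array([0.1, 0.2])}, '/tmp/data.R')
#     >>> read_rdump('/tmp/data.R')
#     OrderedDict([('N', array(2)), ('y', array([0.1, 0.2]))])
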
def to_dataframe(fit, pars=None, permuted=False, dtypes=None, inc_warmup=False,
                 diagnostics=True, header=True):
    """Extract samples as a pandas dataframe for different parameters.

    Parameters
    ----------
    pars : {str, sequence of str}
        parameter (or quantile) name(s).
    permuted : bool
        If True, returned samples are permuted.
        If inc_warmup is True, warmup samples have negative order.
    dtypes : dict
        datatype of parameter(s).
        If nothing is passed, float will be used for all parameters.
    inc_warmup : bool
        If True, warmup samples are kept; otherwise they are discarded.
    diagnostics : bool
        If True, include hmc diagnostics in dataframe.
    header : bool
        If True, include header columns.

    Returns
    -------
    df : pandas dataframe
        Returned dataframe contains: [header_df]|[draws_df]|[diagnostics_df],
        where all groups are optional. To exclude draws_df use `pars=[]`.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("Pandas module not found. You can install pandas with: pip install pandas")
    fit._verify_has_samples()
    pars_original = pars
    if pars is None:
        pars = fit.sim['pars_oi']
    elif isinstance(pars, string_types):
        pars = [pars]
    if pars:
        pars = pystan.misc._remove_empty_pars(pars, fit.sim['pars_oi'], fit.sim['dims_oi'])
        allpars = fit.sim['pars_oi'] + fit.sim['fnames_oi']
        _check_pars(allpars, pars)
    if dtypes is None:
        dtypes = {}
    n_kept = [s if inc_warmup else s - w for s, w in zip(fit.sim['n_save'], fit.sim['warmup2'])]
    chains = len(fit.sim['samples'])
    diagnostic_type = {'divergent__': int,
                       'energy__': float,
                       'treedepth__': int,
                       'accept_stat__': float,
                       'stepsize__': float,
                       'n_leapfrog__': int}
    header_dict = OrderedDict()
    if header:
        idx = np.concatenate([np.full(n_kept[chain], chain, dtype=int) for chain in range(chains)])
        warmup = [np.zeros(n_kept[chain], dtype=np.int64) for chain in range(chains)]
        if inc_warmup:
            draw = []
            for chain, w in zip(range(chains), fit.sim['warmup2']):
                warmup[chain][:w] = 1
                draw.append(np.arange(n_kept[chain], dtype=np.int64) - w)
            draw = np.concatenate(draw)
        else:
            draw = np.concatenate([np.arange(n_kept[chain], dtype=np.int64) for chain in range(chains)])
        warmup = np.concatenate(warmup)
        header_dict = OrderedDict(zip(['chain', 'draw', 'warmup'], [idx, draw, warmup]))
    if permuted:
        if inc_warmup:
            chain_permutation = []
            chain_permutation_order = []
            permutation = []
            permutation_order = []
            for chain, p, w in zip(range(chains), fit.sim['permutation'], fit.sim['warmup2']):
                # use arrays (rather than lists) so the integer offsets
                # below broadcast elementwise
                chain_permutation.append(np.concatenate([np.arange(-w, 0), p]))
                chain_permutation_order.append(np.concatenate([np.arange(-w, 0), np.argsort(p)]))
                permutation.append(sum(n_kept[:chain]) + chain_permutation[-1] + w)
                permutation_order.append(sum(n_kept[:chain]) + chain_permutation_order[-1] + w)
            chain_permutation = np.concatenate(chain_permutation)
            chain_permutation_order = np.concatenate(chain_permutation_order)
            permutation = np.concatenate(permutation)
            permutation_order = np.concatenate(permutation_order)
        else:
            chain_permutation = np.concatenate(fit.sim['permutation'])
            chain_permutation_order = np.concatenate([np.argsort(item) for item in fit.sim['permutation']])
            permutation = np.concatenate([sum(n_kept[:chain]) + p for chain, p in enumerate(fit.sim['permutation'])])
            permutation_order = np.argsort(permutation)
        header_dict["permutation"] = permutation
        header_dict["chain_permutation"] = chain_permutation
        header_dict["permutation_order"] = permutation_order
        header_dict["chain_permutation_order"] = chain_permutation_order
    if header:
        header_df = pd.DataFrame.from_dict(header_dict)
    else:
        if permuted:
            header_df = pd.DataFrame.from_dict({"permutation_order": header_dict["permutation_order"]})
        else:
            header_df = pd.DataFrame()
    fnames_set = set(fit.sim['fnames_oi'])
    pars_set = set(pars)
    if pars_original is None or fnames_set == pars_set:
        dfs = [pd.DataFrame.from_dict(pyholder.chains).iloc[-n:]
               for pyholder, n in zip(fit.sim['samples'], n_kept)]
        df = pd.concat(dfs, axis=0, sort=False, ignore_index=True)
        if dtypes:
            par_keys = dict()
            if not fnames_set.issuperset(pars_set):
                par_keys = OrderedDict([(par, []) for par in fit.sim['pars_oi']])
                for key in fit.sim['fnames_oi']:
                    par = key.split("[")
                    par = par[0]
                    par_keys[par].append(key)
            for par, dtype in dtypes.items():
                if isinstance(dtype, (float, np.float64)):
                    continue
                for key in par_keys.get(par, [par]):
                    df.loc[:, key] = df.loc[:, key].astype(dtype)
    elif pars:
        par_keys = dict()
        if not fnames_set.issuperset(pars_set):
            par_keys = OrderedDict([(par, []) for par in fit.sim['pars_oi']])
            for key in fit.sim['fnames_oi']:
                par = key.split("[")
                par = par[0]
                par_keys[par].append(key)
        columns = []
        for par in pars:
            columns.extend(par_keys.get(par, [par]))
        columns = list(np.unique(columns))
        df = pd.DataFrame(index=np.arange(sum(n_kept)), columns=columns, dtype=float)
        for key in columns:
            key_values = []
            for chain, (pyholder, n) in enumerate(zip(fit.sim['samples'], n_kept)):
                key_values.append(pyholder.chains[key][-n:])
            df.loc[:, key] = np.concatenate(key_values)
        for par, dtype in dtypes.items():
            if isinstance(dtype, (float, np.float64)):
                continue
            for key in par_keys.get(par, [par]):
                df.loc[:, key] = df.loc[:, key].astype(dtype)
    else:
        df = pd.DataFrame()
    if diagnostics:
        diagnostics_dfs = []
        for idx, (pyholder, permutation, n) in enumerate(zip(fit.sim['samples'], fit.sim['permutation'], n_kept), 1):
            diagnostics_df = pd.DataFrame(pyholder['sampler_params'], index=pyholder['sampler_param_names']).T
            diagnostics_df = diagnostics_df.iloc[-n:, :]
            for key, dtype in diagnostic_type.items():
                if key in diagnostics_df:
                    diagnostics_df.loc[:, key] = diagnostics_df.loc[:, key].astype(dtype)
            diagnostics_dfs.append(diagnostics_df)
        if diagnostics_dfs:
            diagnostics_df = pd.concat(diagnostics_dfs, axis=0, sort=False, ignore_index=True)
        else:
            diagnostics_df = pd.DataFrame()
    else:
        diagnostics_df = pd.DataFrame()
    df = pd.concat((header_df, df, diagnostics_df), axis=1, sort=False)
    if permuted:
        df.sort_values(by='permutation_order', inplace=True)
        if not header:
            df.drop(columns='permutation_order', inplace=True)
    return df


def get_stepsize(fit):
    """Parse stepsize from fit object

    Parameters
    ----------
    fit : StanFit4Model

    Returns
    -------
    list
        Returns an empty list if step sizes are not found in
        ``fit.get_adaptation_info``.
    """
    fit._verify_has_samples()
    stepsizes = []
    for adaptation_info in fit.get_adaptation_info():
        for line in adaptation_info.splitlines():
            if "Step size" in line:
                stepsizes.append(float(line.split("=")[1].strip()))
                break
    return stepsizes


def get_inv_metric(fit, as_dict=False):
    """Parse inverse metric from the fit object

    Parameters
    ----------
    fit : StanFit4Model
    as_dict : bool, optional

    Returns
    -------
    list or dict
        Returns an empty list if the inverse metric is not found in
        ``fit.get_adaptation_info()``. If `as_dict`, returns a dictionary
        which can be used with the `.sampling` method.
    """
    fit._verify_has_samples()
    inv_metrics = []
    if not (("ctrl" in fit.stan_args[0]) and ("sampling" in fit.stan_args[0]["ctrl"])):
        return inv_metrics
    metric = [args["ctrl"]["sampling"]["metric"].name for args in fit.stan_args]
    for adaptation_info, metric_name in zip(fit.get_adaptation_info(), metric):
        iter_adaptation_info = iter(adaptation_info.splitlines())
        inv_metric_list = []
        for line in iter_adaptation_info:
            if any(value in line for value in ["Step size", "Adaptation"]):
                continue
            elif "inverse mass matrix" in line:
                for line in iter_adaptation_info:
                    stripped_set = set(line.replace("# ", "").replace(" ", "").replace(",", ""))
                    if stripped_set.issubset(set(".-1234567890e")):
                        inv_metric = np.array(list(map(float, line.replace("# ", "").strip().split(","))))
                        if metric_name == "DENSE_E":
                            inv_metric = np.atleast_2d(inv_metric)
                        inv_metric_list.append(inv_metric)
                    else:
                        break
        inv_metrics.append(np.concatenate(inv_metric_list))
    return inv_metrics if not as_dict else dict(enumerate(inv_metrics))


def get_last_position(fit, warmup=False):
    """Parse last position from fit object

    Parameters
    ----------
    fit : StanFit4Model
    warmup : bool
        If True, returns the last warmup position, when warmup has been done.
        Otherwise function returns the first sample position.

    Returns
    -------
    list
        list contains a dictionary of last draw from each chain.
    """
    fit._verify_has_samples()
    positions = []
    extracted = fit.extract(permuted=False, pars=fit.model_pars, inc_warmup=warmup)
    draw_location = -1
    if warmup:
        draw_location += max(1, fit.sim["warmup"])
    chains = fit.sim["chains"]
    for i in range(chains):
        extract_pos = {key: values[draw_location, i] for key, values in extracted.items()}
        positions.append(extract_pos)
    return positions