jim

Jim

Bases: object

Master class for interfacing with flowMC

Source code in jimgw/jim.py
class Jim(object):
    """
    Master class for interfacing with flowMC

    """

    def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
        self.Likelihood = likelihood
        self.Prior = prior
        seed = kwargs.get("seed", 0)
        n_chains = kwargs.get("n_chains", 20)

        rng_key_set = initialize_rng_keys(n_chains, seed=seed)
        num_layers = kwargs.get("num_layers", 10)
        hidden_size = kwargs.get("hidden_size", [128,128])
        num_bins = kwargs.get("num_bins", 8)

        local_sampler_arg = kwargs.get("local_sampler_arg", {})

        local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix

        model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1])
        self.Sampler = Sampler(
            self.Prior.n_dim,
            rng_key_set,
            None,
            local_sampler,
            model,
            **kwargs)


    def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 100, n_loops: int = 2000, seed = 92348):
        bounds = jnp.array(bounds).T
        key = jax.random.PRNGKey(seed)
        initial_guess = self.Prior.sample(key, set_nwalkers)

        y = lambda x: -self.posterior(x, None)
        y = jax.jit(jax.vmap(y))
        print("Compiling likelihood function")
        y(initial_guess)
        print("Done compiling")

        print("Starting the optimizer")
        optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose = True)
        state = optimizer.optimize(y, bounds, n_loops=n_loops)
        best_fit = optimizer.get_result()[0]
        return best_fit

    def posterior(self, params: Array, data: dict):
        named_params = self.Prior.add_name(params, transform_name=True, transform_value=True)
        return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(params)

    def sample(self, key: jax.random.PRNGKey,
               initial_guess: Array = None):
        if initial_guess is None:
            initial_guess = self.Prior.sample(key, self.Sampler.n_chains)
        self.Sampler.sample(initial_guess, None)

    def print_summary(self):
        """
        Generate summary of the run

        """

        train_summary = self.Sampler.get_sampler_state(training=True)
        production_summary = self.Sampler.get_sampler_state(training=False)

        training_chain: Array = train_summary["chains"]
        training_log_prob: Array = train_summary["log_prob"]
        training_local_acceptance: Array = train_summary["local_accs"]
        training_global_acceptance: Array = train_summary["global_accs"]
        training_loss: Array = train_summary["loss_vals"]

        production_chain: Array = production_summary["chains"]
        production_log_prob: Array = production_summary["log_prob"]
        production_local_acceptance: Array = production_summary["local_accs"]
        production_global_acceptance: Array = production_summary["global_accs"]

        print("Training summary")
        print('=' * 10)
        for index in range(len(self.Prior.naming)):
            print(f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}")
        print(f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}") 
        print(f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}")
        print(f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}")
        print(f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}")

        print("Production summary")
        print('=' * 10)
        for index in range(len(self.Prior.naming)):
            print(f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}")
        print(f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}")
        print(f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}")
        print(f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}")

    def get_samples(self, training: bool = False) -> dict:
        """
        Get the samples from the sampler

        Args:
            training (bool, optional): If True, return the training samples. Defaults to False.

        Returns:
            dict: Dictionary of samples, keyed by parameter name.
        """
        if training:
            chains = self.Sampler.get_sampler_state(training=True)["chains"]
        else:
            chains = self.Sampler.get_sampler_state(training=False)["chains"]

        chains = self.Prior.add_name(chains.transpose(2,0,1), transform_name=True)
        return chains

    def plot(self):
        pass
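
A minimal usage sketch (not part of the library source): it assumes you already have concrete LikelihoodBase and Prior instances from jimgw; my_likelihood and my_prior below are hypothetical placeholders, and the keyword arguments shown are the defaults read in __init__ and forwarded to the flowMC Sampler.

import jax

from jimgw.jim import Jim

jim = Jim(
    my_likelihood,   # hypothetical LikelihoodBase instance
    my_prior,        # hypothetical Prior instance
    seed=0,          # RNG seed used to initialize the chains
    n_chains=20,     # number of parallel chains
    num_layers=10,   # depth of the MaskedCouplingRQSpline flow
    hidden_size=[128, 128],
    num_bins=8,
)

key = jax.random.PRNGKey(42)
jim.sample(key)              # initial positions drawn from the prior, then flowMC runs
jim.print_summary()          # per-parameter mean +/- std, acceptance rates, loss range
samples = jim.get_samples()  # dict of production chains keyed by parameter name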

get_samples(training=False)

Get the samples from the sampler

Parameters:

    training (bool, optional): If True, return the training samples. Defaults to False.

Returns:

    dict: Dictionary of samples, keyed by parameter name.

Source code in jimgw/jim.py
def get_samples(self, training: bool = False) -> dict:
    """
    Get the samples from the sampler

    Args:
        training (bool, optional): If True, return the training samples. Defaults to False.

    Returns:
        dict: Dictionary of samples, keyed by parameter name.
    """
    if training:
        chains = self.Sampler.get_sampler_state(training=True)["chains"]
    else:
        chains = self.Sampler.get_sampler_state(training=False)["chains"]

    chains = self.Prior.add_name(chains.transpose(2,0,1), transform_name=True)
    return chains
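
A short post-processing sketch (not library code), assuming jim is a Jim instance on which sample has already been run; each entry of the returned dict is assumed to hold that parameter's chains as an array of shape (n_chains, n_steps).

import jax.numpy as jnp

samples = jim.get_samples()    # dict: parameter name -> chain values
for name, values in samples.items():
    flat = values.reshape(-1)  # merge chains and steps into one axis
    lo, med, hi = jnp.quantile(flat, jnp.array([0.05, 0.5, 0.95]))
    print(f"{name}: {float(med):.3f} (90% interval {float(lo):.3f} to {float(hi):.3f})")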

print_summary()

Generate summary of the run

Source code in jimgw/jim.py
def print_summary(self):
    """
    Generate summary of the run

    """

    train_summary = self.Sampler.get_sampler_state(training=True)
    production_summary = self.Sampler.get_sampler_state(training=False)

    training_chain: Array = train_summary["chains"]
    training_log_prob: Array = train_summary["log_prob"]
    training_local_acceptance: Array = train_summary["local_accs"]
    training_global_acceptance: Array = train_summary["global_accs"]
    training_loss: Array = train_summary["loss_vals"]

    production_chain: Array = production_summary["chains"]
    production_log_prob: Array = production_summary["log_prob"]
    production_local_acceptance: Array = production_summary["local_accs"]
    production_global_acceptance: Array = production_summary["global_accs"]

    print("Training summary")
    print('=' * 10)
    for index in range(len(self.Prior.naming)):
        print(f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}")
    print(f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}") 
    print(f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}")
    print(f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}")
    print(f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}")

    print("Production summary")
    print('=' * 10)
    for index in range(len(self.Prior.naming)):
        print(f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}")
    print(f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}")
    print(f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}")
    print(f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}")