Indexing objects with PyMC coords
Understanding indexing and coords in PyMC
The following is a reminder to myself of how indexing works in PyMC, so that I index correctly in the future.
Imagine that we are interested in the behaviour of two species of bear. We might observe some attribute for each bear we see:
```python
import numpy as np
import polars as pl

observation = np.array([0.1, 0.8, 0.9, 0.5, 0.1, 0.2, 0.7, 0.1])
species = np.array(["A", "B", "B", "B", "A", "A", "B", "A"])
pl.DataFrame({"species": species, "observation": observation})
```
| species (str) | observation (f64) |
|---|---|
| "A" | 0.1 |
| "B" | 0.8 |
| "B" | 0.9 |
| "B" | 0.5 |
| "A" | 0.1 |
| "A" | 0.2 |
| "B" | 0.7 |
| "A" | 0.1 |
At the end of our analysis, we want the mean observation for each bear species, with some spread. The final answer for the mean might be something like [0.15, 0.75], i.e. an array of length 2 (one entry per species).
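For reference, the raw per-species sample means can be computed directly with boolean masks, no model needed. This is only a sanity check, not the Bayesian answer: it carries no spread, and the model's prior will pull its estimates away from these values.

```python
import numpy as np

observation = np.array([0.1, 0.8, 0.9, 0.5, 0.1, 0.2, 0.7, 0.1])
species = np.array(["A", "B", "B", "B", "A", "A", "B", "A"])

# Plain sample mean per species: select each species' observations with a boolean mask
sample_means = {s: float(observation[species == s].mean()) for s in ("A", "B")}
print(sample_means)  # roughly A: 0.125, B: 0.725 (up to floating-point rounding)
```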
First, we need to generate an integer index that we can use for indexing purposes instead of our "A" and "B" labels. Pandas has a nice function for this:
```python
import pandas as pd

species_index, species_unique = pd.factorize(species)
# species_index = [0, 1, 1, 1, 0, 0, 1, 0]
# species_unique = ['A', 'B']
```
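A quick way to convince yourself what the integer index means: indexing the short array of unique labels with the long index recovers the original labels. That "short per-group array indexed by a long per-observation index" pattern is the one we want for the distributions too. The sketch also shows a NumPy-only alternative, np.unique with return_inverse=True; note that np.unique sorts the labels while pd.factorize keeps order of first appearance, so the two only agree here because "A" both appears first and sorts first.

```python
import numpy as np
import pandas as pd

species = np.array(["A", "B", "B", "B", "A", "A", "B", "A"])
species_index, species_unique = pd.factorize(species)

# Round trip: length-2 labels indexed by length-8 codes gives back the length-8 labels
assert (species_unique[species_index] == species).all()

# NumPy alternative: uniques come back sorted, inverse maps each element to its unique
np_unique, np_index = np.unique(species, return_inverse=True)
assert (np_index == species_index).all()  # same codes for this particular data
```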
Now, at some point it might be tempting to index the observation array with species_index:
```python
observation[species_index]  # <--- DON'T do this
# array([0.1, 0.8, 0.8, 0.8, 0.1, 0.1, 0.8, 0.1])
```
But if we do that, we just get the first two observations repeated! species_index only contains the values 0 and 1, so it can only ever select observation[0] and observation[1].
Instead, we want to index the distributions themselves. Let's make a dummy collection of distributions that stands in for pm.Normal("bear_attribute", dims="species") (the dummy can be instantiated without a model, which makes this a useful exercise for seeing what happens behind the scenes).
```python
from dataclasses import dataclass

@dataclass
class OneNormalDist:
    name: str

# One entry per unique species, i.e. the same length as species_unique
PyMCNormal = np.array([OneNormalDist("A"), OneNormalDist("B")])
```
This is the object that we want to index! It tells PyMC which distribution (A or B) belongs to which observation.
```python
PyMCNormal[species_index]
# array([OneNormalDist(name='A'), OneNormalDist(name='B'),
#        OneNormalDist(name='B'), OneNormalDist(name='B'),
#        OneNormalDist(name='A'), OneNormalDist(name='A'),
#        OneNormalDist(name='B'), OneNormalDist(name='A')], dtype=object)
```
```python
# Here's a dataframe representation to get the idea across:
pl.DataFrame(
    {
        "species": species,
        "observation": observation,
        "distribution": PyMCNormal[species_index],
    }
)
```
| species (str) | observation (f64) | distribution (object) |
|---|---|---|
| "A" | 0.1 | OneNormalDist(name='A') |
| "B" | 0.8 | OneNormalDist(name='B') |
| "B" | 0.9 | OneNormalDist(name='B') |
| "B" | 0.5 | OneNormalDist(name='B') |
| "A" | 0.1 | OneNormalDist(name='A') |
| "A" | 0.2 | OneNormalDist(name='A') |
| "B" | 0.7 | OneNormalDist(name='B') |
| "A" | 0.1 | OneNormalDist(name='A') |
Here's how you would model our bear example in PyMC with this indexing:
```python
import pymc as pm

with pm.Model(coords={"species": species_unique}) as model:
    # One mean per species: shape (2,), labelled by the "species" coord
    bear_mu = pm.Normal(name="bear_attr_mu", mu=0.5, sigma=0.1, dims="species")
    pm.Normal(
        name="likelihood",
        mu=bear_mu[species_index],  # <--- We index here! shape (2,) -> shape (8,)
        sigma=0.5,
        observed=observation,
    )
    trace = pm.sample()
```
```
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [bear_attr_mu]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
```
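To see what the indexed likelihood scores, here is a minimal NumPy sketch (not PyMC's implementation) of the Normal log-likelihood for one candidate value of bear_attr_mu. The helper names normal_logpdf and log_likelihood are hypothetical; sigma=0.5 matches the model. The indexing step is the same one as in the model: a length-2 vector of species means becomes a length-8 vector of per-observation means.

```python
import numpy as np

observation = np.array([0.1, 0.8, 0.9, 0.5, 0.1, 0.2, 0.7, 0.1])
species_index = np.array([0, 1, 1, 1, 0, 0, 1, 0])

def normal_logpdf(x, mu, sigma):
    # Elementwise log density of Normal(mu, sigma) evaluated at x
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def log_likelihood(bear_mu, sigma=0.5):
    # Expand per-species means (length 2) to per-observation means (length 8)
    mu_per_obs = bear_mu[species_index]
    return normal_logpdf(observation, mu_per_obs, sigma).sum()

ll_means = log_likelihood(np.array([0.125, 0.725]))    # the per-species sample means
ll_swapped = log_likelihood(np.array([0.725, 0.125]))  # species deliberately mixed up
print(ll_means > ll_swapped)  # True
```

With a fixed sigma, the per-species sample means maximise this function, which is why swapping the two species' means scores so much worse.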
```python
trace["posterior"]["bear_attr_mu"].mean(dim=("chain", "draw"))
```
```
<xarray.DataArray 'bear_attr_mu' (species: 2)> Size: 16B
array([0.44812854, 0.53058651])
Coordinates:
  * species  (species) <U1 8B 'A' 'B'
```
```python
import arviz as az

az.plot_posterior(trace);
```
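The posterior means printed earlier, about [0.448, 0.531], sit much closer to 0.5 than the raw sample means (0.125 and 0.725). That is the prior at work: sigma=0.1 on bear_attr_mu is tight compared with the likelihood sigma=0.5 and only four bears per species. Because the prior and likelihood are both Normal with a known sigma, the posterior has a closed form (the standard Normal-Normal conjugate update), so the sampler's output can be checked by hand. This sketch uses np.bincount on the integer index to get per-species counts and sums; it agrees with the sampled values to about two decimal places.

```python
import numpy as np

observation = np.array([0.1, 0.8, 0.9, 0.5, 0.1, 0.2, 0.7, 0.1])
species_index = np.array([0, 1, 1, 1, 0, 0, 1, 0])

prior_mu, prior_sigma = 0.5, 0.1  # prior on bear_attr_mu in the model
obs_sigma = 0.5                   # fixed sigma of the likelihood in the model

# Per-species count and sum of observations, grouped by the integer index
n = np.bincount(species_index)                           # [4, 4]
total = np.bincount(species_index, weights=observation)  # about [0.5, 2.9]

# Conjugate update: precision-weighted average of the prior mean and the data
prior_prec = 1 / prior_sigma**2  # about 100
data_prec = n / obs_sigma**2     # [16, 16]
post_mean = (prior_prec * prior_mu + total / obs_sigma**2) / (prior_prec + data_prec)
post_sd = np.sqrt(1 / (prior_prec + data_prec))
print(post_mean)  # about [0.448, 0.531]
```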