Hello,
I am writing to ask a couple of questions about your code base. I would like to use it to estimate the average treatment effect (ATE) on a confounded dataset with treatment $X$, outcome $Y$, and four pretreatment covariates $Z_1, Z_2, Z_3, Z_4$. The joint distribution of the pretreatment covariates factorises as $P(Z_1)~P(Z_2\mid Z_1)~P(Z_3\mid Z_1,~Z_2)~P(Z_4\mid Z_1,~Z_2,~Z_3)$. Our aim is to infer the ATE of $X$ on $Y$. We define the following class:
import torch
import torch.nn.functional as F
from causal_nf.sem_equations.sem_base import SEM


class TestModel(SEM):
    def __init__(self):
        functions = None
        inverses = None
        super().__init__(functions, inverses, None)

    def adjacency(self, add_diag=False):
        adj = torch.zeros((6, 6))
        adj[0, :] = torch.tensor([0, 0, 0, 0, 0, 0])  # Z1
        adj[1, :] = torch.tensor([1, 0, 0, 0, 0, 0])  # Z2
        adj[2, :] = torch.tensor([1, 1, 0, 0, 0, 0])  # Z3
        adj[3, :] = torch.tensor([1, 1, 1, 0, 0, 0])  # Z4
        adj[4, :] = torch.tensor([1, 1, 1, 1, 0, 0])  # X
        adj[5, :] = torch.tensor([1, 1, 1, 1, 1, 0])  # Y
        if add_diag:
            adj += torch.eye(6)
        return adj

    def intervention_index_list(self):
        return [0, 4]
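For reference, the graph I intend this matrix to encode can be checked independently of the library. The sketch below is a minimal plain-Python check, assuming that row i marks the parents of node i (as the row comments above suggest). The names `NODES`, `ADJ`, `is_strictly_lower_triangular`, and `parents` are my own illustrations and are not part of causal_nf.

```python
# Plain-Python copy of the matrix returned by TestModel.adjacency().
# Assumption: row i lists the parents of node i (1 = edge from the column node).
NODES = ["Z1", "Z2", "Z3", "Z4", "X", "Y"]
ADJ = [
    [0, 0, 0, 0, 0, 0],  # Z1
    [1, 0, 0, 0, 0, 0],  # Z2
    [1, 1, 0, 0, 0, 0],  # Z3
    [1, 1, 1, 0, 0, 0],  # Z4
    [1, 1, 1, 1, 0, 0],  # X
    [1, 1, 1, 1, 1, 0],  # Y
]


def is_strictly_lower_triangular(adj):
    """Return True if every entry on or above the diagonal is zero."""
    n = len(adj)
    return all(adj[i][j] == 0 for i in range(n) for j in range(i, n))


def parents(adj, nodes, child):
    """Return the names of the parents of `child` under the row convention."""
    row = adj[nodes.index(child)]
    return [nodes[j] for j, flag in enumerate(row) if flag]


# A strictly lower-triangular matrix is acyclic in the given node order.
assert is_strictly_lower_triangular(ADJ)
print("parents of X:", parents(ADJ, NODES, "X"))
print("parents of Y:", parents(ADJ, NODES, "Y"))
```

Under this convention, all four covariates are parents of both $X$ and $Y$, and $X$ is a parent of $Y$, which is the confounded structure described above.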
I have written custom Preparator and DataLoader classes and a config file. The model has already been fitted, and the code below loads it from the last checkpoint. It then estimates the ATE five times, each with a different random seed, by sampling from the fitted model. However, the ATE estimates are not consistently close to the true value in my benchmark dataset, and I wonder if you could point out any possible issues in the code pasted below:
import causal_nf.config as causal_nf_config
from causal_nf.config import cfg
import causal_nf.utils.training as causal_nf_train
from yacs.config import CfgNode
import torch
import causal_nf.utils.io as causal_nf_io
import numpy as np
from causal_nf.preparators.MY_preparator import MYPreparator

# NOTE: folder, ckpt_code, and check_file are defined elsewhere (not shown).
seed = 10
args_list = []
args = CfgNode({
    'config_file': f'{folder}/{ckpt_code}/wandb_local/config_local.yaml',
    'config_default_file': f'{folder}/{ckpt_code}/wandb_local/default_config.yaml',
    'project': None, 'wandb_mode': 'disabled', 'wandb_group': None,
    'load_model': f'{folder}/{ckpt_code}', 'delete_ckpt': False,
})

config = causal_nf_config.build_config(
    config_file=args.config_file,
    args_list=args_list,
    config_default_file=args.config_default_file,
)
causal_nf_config.assert_cfg_and_config(cfg, config)

preparator = MYPreparator.loader(cfg.dataset)
preparator.prepare_data()

# Load the fitted model from the checkpoint and switch to evaluation mode.
model_lightning = causal_nf_train.load_model(cfg=cfg, preparator=preparator, ckpt_file=check_file)
model = model_lightning.model
model.eval()

loaders = preparator.get_dataloaders(
    batch_size=cfg.train.batch_size, num_workers=cfg.train.num_workers
)

n_rounds = 5
ates = []
seeds = np.arange(n_rounds)
batch = next(iter(loaders[-1]))

# Estimate the ATE of X (index 4) for do(X=1) versus do(X=0), once per seed.
for i, seed in enumerate(seeds):
    int_dict = {'name': '1_0', 'a': 1., 'b': 0., 'index': 4}
    name = int_dict["name"]
    a = int_dict["a"]  # 1.
    b = int_dict["b"]  # 0.
    index = int_dict["index"]
    torch.random.manual_seed(seed)
    ate = model_lightning.model.compute_ate(
        index,
        a=a,
        b=b,
        num_samples=10000,
        scaler=preparator.scaler_transform,
    )
    ates.append(ate.detach().numpy())
    print(ates[-1])
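To make "consistently close" concrete, the per-seed estimates can be summarised by their mean and standard error and compared with the benchmark value. The sketch below is only an illustration of that comparison: `summarise_ates` is a hypothetical helper of my own, not part of causal_nf, and the numbers in the usage example are placeholders, not results from my model.

```python
import math
import statistics


def summarise_ates(ates, true_ate):
    """Summarise per-seed ATE estimates against a known benchmark value.

    `ates` is a sequence of at least two floats; `true_ate` is the benchmark ATE.
    """
    mean = statistics.fmean(ates)
    # Standard error of the mean across seeds (sample std / sqrt(n)).
    std_err = statistics.stdev(ates) / math.sqrt(len(ates))
    return {"mean": mean, "std_err": std_err, "bias": mean - true_ate}


# Placeholder values for illustration only; not output from the fitted model.
example = summarise_ates([1.0, 1.2, 0.8, 1.1, 0.9], true_ate=1.0)
print(example)
```

With the loop above, I would pass the collected estimates converted to plain floats (for example `[float(a) for a in ates]`, assuming each estimate is a single-element array) together with the true ATE of the benchmark. A bias that is large relative to the standard error would point to a systematic problem rather than sampling noise.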
Thanks!