
Sample data for training? #29

@WeileiZeng

After constructing the model with the following code:

```python
# ./model/run.py
from nequip import model_from_config, default_config

cfg = default_config()
cfg.scale = 1.0
cfg.shift = 0.0
model = model_from_config(cfg)
print(model)  # the model is constructed successfully
```

How can we obtain training data in the following format?

```python
import jraph

# Model: NequIPEnergyModel, defined in ./model/nequip.py
# Model input: a jraph.GraphsTuple
graph = jraph.GraphsTuple(
    nodes=nodes,
    edges=edges,
    receivers=receivers,
    senders=senders,
    globals=globals_,
    n_node=n_node,
    n_edge=n_edge,
)
```
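A minimal sketch of one way to produce these fields, assuming each training structure is given as Cartesian positions plus integer species labels. It builds a plain radius-cutoff graph (no periodic boundary conditions), uses one-hot species as `nodes` and displacement vectors `r_receiver - r_sender` as `edges`, and leaves `globals` empty. The helper `build_graph_fields`, the one-hot encoding, the edge direction, and `globals=None` are all assumptions for illustration; the exact node and edge features the model expects should be checked against `./model/nequip.py`.

```python
import numpy as np


def build_graph_fields(positions, species, num_species, cutoff):
    """Hypothetical helper: radius-graph fields for one structure (no PBC).

    Assumption: nodes are one-hot species, edges are r_receiver - r_sender.
    """
    positions = np.asarray(positions, dtype=np.float32)
    n_atoms = positions.shape[0]
    # disp[i, j] = r_j - r_i for every ordered pair of atoms.
    disp = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(disp, axis=-1)
    # Keep ordered pairs (i, j), i != j, that lie within the cutoff.
    mask = (dist < cutoff) & ~np.eye(n_atoms, dtype=bool)
    senders, receivers = np.nonzero(mask)
    return dict(
        nodes=np.eye(num_species, dtype=np.float32)[species],  # (n_atoms, num_species)
        edges=disp[senders, receivers],  # (n_edge, 3) displacement vectors
        senders=senders.astype(np.int32),
        receivers=receivers.astype(np.int32),
        globals=None,  # assumption: no per-structure input features
        n_node=np.array([n_atoms], dtype=np.int32),
        n_edge=np.array([senders.size], dtype=np.int32),
    )
```

The keys match the `jraph.GraphsTuple` field names, so the result can be unpacked as `jraph.GraphsTuple(**fields)`, and several structures can be combined into one batch with `jraph.batch`. The reference energies to train against have to come from whatever labeled dataset you use; this sketch only covers the input side.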

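```python
import functools

import e3nn_jax as e3nn
import jax
import jraph

# Model output: sum the per-atom outputs into one output per graph.
partial = functools.partial
# Treat each e3nn.IrrepsArray as a single leaf instead of recursing into it.
tree_map = partial(
    jax.tree_util.tree_map, is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray)
)
global_output = tree_map(
    lambda n: jraph.segment_sum(n, node_gr_idx, n_graph), atomic_output
)
# global_output is the model output
```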

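```python
# The same computation as a single call. The positional arguments (function,
# tree) must come before the keyword argument is_leaf, otherwise Python
# raises a SyntaxError.
global_output = jax.tree_util.tree_map(
    lambda n: jraph.segment_sum(n, node_gr_idx, n_graph),
    atomic_output,
    is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray),
)
# where atomic_output is the per-atom output of the neural network
```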
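The names `node_gr_idx` and `n_graph` in the output snippet are not defined there. A plausible reading, assumed here rather than confirmed from `./model/nequip.py`, is that `n_graph` is the number of structures in the batch and `node_gr_idx[k]` is the index of the structure that atom `k` belongs to, derived from `n_node`. The NumPy sketch below mimics `jraph.segment_sum` with `np.add.at` to show how per-atom outputs become one value per graph; the helper names `segment_sum` and `graph_energies` are hypothetical.

```python
import numpy as np


def segment_sum(data, segment_ids, num_segments):
    """NumPy stand-in for jraph.segment_sum: sums rows of data sharing a segment id."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # Unbuffered add, so repeated segment ids accumulate correctly.
    np.add.at(out, segment_ids, data)
    return out


def graph_energies(atomic_energies, n_node):
    """Sum per-atom energies into one energy per graph (assumed batching scheme)."""
    n_node = np.asarray(n_node)
    n_graph = n_node.shape[0]
    # node_gr_idx[k] is the index of the graph that node k belongs to.
    node_gr_idx = np.repeat(np.arange(n_graph), n_node)
    return segment_sum(np.asarray(atomic_energies), node_gr_idx, n_graph)
```

For example, per-atom energies `[[1.0], [2.0], [3.0], [4.0], [5.0]]` with `n_node = [3, 2]` give `[[6.0], [9.0]]`: graph 0 sums the first three atoms and graph 1 the last two. These per-graph values are what would be compared against per-structure reference energies during training.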

Originally posted by @WeileiZeng in #28 (comment)
