MACE

Resources

Foundation model

The checkpoint 2023-12-10-mace-128-L0_epoch-199.model has the following module structure (as printed by PyTorch):

ScaleShiftMACE(
  (node_embedding): LinearNodeEmbeddingBlock(
    (linear): Linear(89x0e -> 128x0e | 11392 weights)
  )
  (radial_embedding): RadialEmbeddingBlock(
    (bessel_fn): BesselBasis(r_max=6.0, num_basis=10, trainable=False)
    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=6.0)
  )
  (spherical_harmonics): SphericalHarmonics()
  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-3.6672, -1.3321, -3.4821, -4.7367, -7.7249, -8.4056, -7.3601, -7.2846, -4.8965, 0.0000, -2.7594, -2.8140, -4.8469, -7.6948, -6.9633, -4.6726, -2.8117, -0.0626, -2.6176, -5.3905, -7.8858, -10.2684, -8.6651, -9.2331, -8.3050, -7.0490, -5.5774, -5.1727, -3.2521, -1.2902, -3.5271, -4.7085, -3.9765, -3.8862, -2.5185, 6.7669, -2.5635, -4.9380, -10.1498, -11.8469, -12.1389, -8.7917, -8.7869, -7.7809, -6.8500, -4.8910, -2.0634, -0.6396, -2.7887, -3.8186, -3.5871, -2.8804, -1.6356, 9.8467, -2.7653, -4.9910, -8.9337, -8.7356, -8.0190, -8.2515, -7.5917, -8.1697, -13.5927, -18.5175, -7.6474, -8.1230, -7.6078, -6.8503, -7.8269, -3.5848, -7.4554, -12.7963, -14.1081, -9.3549, -11.3875, -9.6219, -7.3244, -5.3047, -2.3801, 0.2495, -2.3240, -3.7300, -3.4388, -5.0629, -11.0246, -12.2656, -13.8556, -14.9331, -15.2828])
  (interactions): ModuleList(
    (0): RealAgnosticResidualInteractionBlock(
      (linear_up): Linear(128x0e -> 128x0e | 16384 weights)
      (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)
      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]
      (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)
      (skip_tp): FullyConnectedTensorProduct(128x0e x 89x0e -> 128x0e | 1458176 paths | 1458176 weights)
      (reshape): reshape_irreps()
    )
    (1): RealAgnosticResidualInteractionBlock(
      (linear_up): Linear(128x0e -> 128x0e | 16384 weights)
      (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)
      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]
      (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)
      (skip_tp): FullyConnectedTensorProduct(128x0e x 89x0e -> 128x0e | 1458176 paths | 1458176 weights)
      (reshape): reshape_irreps()
    )
  )
  (products): ModuleList(
    (0): EquivariantProductBasisBlock(
      (symmetric_contractions): SymmetricContraction(
        (contractions): ModuleList(
          (0): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.DoubleTensor of size 89x4x128]
                (1): Parameter containing: [torch.DoubleTensor of size 89x1x128]
            )
            (graph_opt_main): GraphModule()
          )
        )
      )
      (linear): Linear(128x0e -> 128x0e | 16384 weights)
    )
    (1): EquivariantProductBasisBlock(
      (symmetric_contractions): SymmetricContraction(
        (contractions): ModuleList(
          (0): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.DoubleTensor of size 89x4x128]
                (1): Parameter containing: [torch.DoubleTensor of size 89x1x128]
            )
            (graph_opt_main): GraphModule()
          )
        )
      )
      (linear): Linear(128x0e -> 128x0e | 16384 weights)
    )
  )
  (readouts): ModuleList(
    (0): LinearReadoutBlock(
      (linear): Linear(128x0e -> 1x0e | 128 weights)
    )
    (1): NonLinearReadoutBlock(
      (linear_1): Linear(128x0e -> 16x0e | 2048 weights)
      (non_linearity): Activation [x] (16x0e -> 16x0e)
      (linear_2): Linear(16x0e -> 1x0e | 16 weights)
    )
  )
  (scale_shift): ScaleShiftBlock(scale=0.804154, shift=0.164097)
)