| | import torch
|
| | import torch.nn as nn
|
| | from transformers import PreTrainedModel, PretrainedConfig
|
| |
|
class FingerNetConfig(PretrainedConfig):
    """Configuration holding the layer sizes used to build ``FingerNet``.

    ``FingerNet`` creates one three-layer MLP ("estimator") per entry of
    ``y_dim``; estimator ``i`` maps ``x_dim[0]`` inputs through hidden layers
    of width ``h1_dim[i]`` and ``h2_dim[i]`` to ``y_dim[i]`` outputs.

    Args:
        x_dim: List whose first element is the input feature size shared by
            all estimators. Defaults to ``[6]``.
        y_dim: Output size of each estimator. Defaults to ``[6, 1800]``.
        h1_dim: First hidden-layer width of each estimator. Defaults to
            ``[100, 1000]``.
        h2_dim: Second hidden-layer width of each estimator. Defaults to
            ``[100, 1000]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fingernet"

    def __init__(
        self,
        x_dim=None,
        y_dim=None,
        h1_dim=None,
        h2_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The defaults are created per call instead of living in the signature:
        # list defaults in a signature are evaluated once and would be shared
        # (and mutated) across every config instance built with defaults.
        self.x_dim = [6] if x_dim is None else x_dim
        self.y_dim = [6, 1800] if y_dim is None else y_dim
        self.h1_dim = [100, 1000] if h1_dim is None else h1_dim
        self.h2_dim = [100, 1000] if h2_dim is None else h2_dim
|
| |
|
| |
|
class FingerNet(PreTrainedModel):
    """Collection of independent MLP estimators that all read the same input.

    One estimator is built per entry of ``config.y_dim`` and stored in
    ``self.model`` under the key ``"estimator_<index>"``. Each estimator is
    ``Linear -> ReLU -> Linear -> ReLU -> Linear``.
    """

    config_class = FingerNetConfig

    def __init__(self, config):
        """Build the estimators described by ``config``.

        Args:
            config: Object providing ``x_dim``, ``y_dim``, ``h1_dim`` and
                ``h2_dim`` (see ``FingerNetConfig``).
        """
        super().__init__(config)

        # Keep local references to the size lists; forward() relies on y_dim
        # to know how many estimators exist.
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.h1_dim = config.h1_dim
        self.h2_dim = config.h2_dim

        self.model = nn.ModuleDict()
        for idx, out_features in enumerate(self.y_dim):
            first_hidden = self.h1_dim[idx]
            second_hidden = self.h2_dim[idx]
            # Layers are created in input-to-output order so parameter
            # creation order matches the Sequential indices 0, 2 and 4.
            layers = [
                nn.Linear(self.x_dim[0], first_hidden),
                nn.ReLU(),
                nn.Linear(first_hidden, second_hidden),
                nn.ReLU(),
                nn.Linear(second_hidden, out_features),
            ]
            self.model[f"estimator_{idx}"] = nn.Sequential(*layers)

        self.post_init()

    def forward(self, x):
        """Run every estimator on ``x``.

        Args:
            x: Input passed unchanged to each estimator; the first layer of
                every estimator is ``nn.Linear(x_dim[0], ...)``, so the last
                dimension of ``x`` is expected to be ``x_dim[0]``.

        Returns:
            A list with one output per estimator, in estimator-index order.
        """
        estimator_keys = (f"estimator_{idx}" for idx, _ in enumerate(self.y_dim))
        return [self.model[key](x) for key in estimator_keys]
|
| |
|