$ pip install mlconfig

config.yaml
num_classes: 50

model:
  name: LeNet
  num_classes: $num_classes

optimizer:
  name: Adam
  lr: 1.e-3
  weight_decay: 1.e-4

...

main.py
import mlconfig
from torch import nn, optim
from torchvision import models

# Register the optimizer with mlconfig; config.yaml selects it via
# ``optimizer.name: Adam``.
mlconfig.register(optim.Adam)


@mlconfig.register
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages, then a 3-layer MLP.

    The first convolution takes a single input channel. The classifier's
    first linear layer expects ``16 * 5 * 5`` flattened features, i.e. 5x5
    feature maps, which is what a 32x32 input yields
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_classes):
        # Zero-argument super() does not rely on the module-level name
        # ``LeNet``, which the class decorator is free to rebind.
        super().__init__()
        self.num_classes = num_classes

        # Feature extractor: (1 -> 6 -> 16 channels), 5x5 kernels, no bias.
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # Classifier head: 400 -> 120 -> 84 -> num_classes.
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, self.num_classes),
        )

    def forward(self, x):
        """Return class scores for the batch ``x``.

        Args:
            x: Input batch; dimension 0 is treated as the batch dimension.

        Returns:
            The output of the classifier head, one row per batch element.
        """
        x = self.features(x)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def main():
    """Load ``config.yaml`` and build the model and optimizer it describes."""
    config = mlconfig.load('config.yaml')
    config.set_immutable()

    # NOTE(review): presumably each call instantiates the registered object
    # named by the section's ``name`` key, passing the remaining keys as
    # keyword arguments -- confirm against the mlconfig documentation.
    model = config.model()
    optimizer = config.optimizer(model.parameters())
    ...  # placeholder from the original example: training code goes here


if __name__ == '__main__':
    main()