
A configuration framework for machine learning


digitaldomain/mlconfig


mlconfig

Installation

$ pip install mlconfig

Usage

config.yaml

num_classes: 50

model:
  name: LeNet
  num_classes: $num_classes

optimizer:
  name: Adam
  lr: 1.e-3
  weight_decay: 1.e-4
...
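To show what a reference such as `$num_classes` is meant to do, here is a minimal pure-Python sketch. It assumes that a string of the form `$key` stands for the value of the top-level key `key` in the same file. `resolve_refs` is a hypothetical helper and is not part of mlconfig, whose own substitution rules may differ.

```python
from typing import Any


def resolve_refs(node: Any, root: dict) -> Any:
    """Return a copy of node with every '$key' string replaced by root[key].

    Assumption (hypothetical semantics): a reference is a whole string value
    naming a top-level key. Cyclic references are not detected.
    """
    if isinstance(node, dict):
        return {k: resolve_refs(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve_refs(v, root) for v in node]
    if isinstance(node, str) and node.startswith('$'):
        key = node[1:]
        if key not in root:
            raise KeyError(f'unresolved reference: {node}')
        # The referenced value may itself contain references.
        return resolve_refs(root[key], root)
    return node


# config.yaml above, written as the dict a YAML parser would produce
# (kept as a literal so the sketch needs no YAML dependency).
raw = {
    'num_classes': 50,
    'model': {'name': 'LeNet', 'num_classes': '$num_classes'},
    'optimizer': {'name': 'Adam', 'lr': 1e-3, 'weight_decay': 1e-4},
}

config = resolve_refs(raw, raw)
print(config['model'])  # {'name': 'LeNet', 'num_classes': 50}
```

Defining `num_classes` once at the top level means the model section and any other section that needs it stay in sync when the value changes.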

main.py

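import mlconfig
from torch import nn, optim

# Register optim.Adam so that `name: Adam` in config.yaml can refer to it.
mlconfig.register(optim.Adam)


# Registering LeNet makes it available to config.yaml as `name: LeNet`.
@mlconfig.register
class LeNet(nn.Module):

    def __init__(self, num_classes):
        super(LeNet, self).__init__()
        self.num_classes = num_classes

        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # 16 * 5 * 5 assumes 1x32x32 inputs:
        # 32 -> conv(5) -> 28 -> pool(2) -> 14 -> conv(5) -> 10 -> pool(2) -> 5
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, self.num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 16 * 5 * 5)
        x = self.classifier(x)
        return x


def main():
    config = mlconfig.load('config.yaml')
    config.set_immutable()  # lock the config against further modification

    model = config.model()  # LeNet with num_classes taken from config.yaml
    optimizer = config.optimizer(model.parameters())  # Adam with lr and weight_decay from config.yaml
    ...


if __name__ == '__main__':
    main()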
