Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions examples/mnist/batch_eth_mnist.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -32,7 +32,7 @@
parser.add_argument("--n_test", type=int, default=10000)
parser.add_argument("--n_train", type=int, default=60000)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--update_steps", type=int, default=256)
parser.add_argument("--n_updates", type=int, default=10)
parser.add_argument("--exc", type=float, default=22.5)
parser.add_argument("--inh", type=float, default=120)
parser.add_argument("--theta_plus", type=float, default=0.05)
Expand All@@ -44,7 +44,7 @@
parser.add_argument("--test", dest="train", action="store_false")
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=True, gpu=True)
parser.set_defaults(plot=False, gpu=True)

args = parser.parse_args()

Expand All@@ -55,7 +55,7 @@
n_test = args.n_test
n_train = args.n_train
n_workers = args.n_workers
update_steps = args.update_steps
n_updates = args.n_updates
exc = args.exc
inh = args.inh
theta_plus = args.theta_plus
Expand All@@ -67,6 +67,7 @@
plot = args.plot
gpu = args.gpu

update_steps = int(n_train / batch_size / n_updates)
update_interval = update_steps * batch_size

device = "cpu"
Expand DownExpand Up@@ -162,14 +163,14 @@
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)

# Train the network.
print("\nBegin training.\n")
print("\nBegin training...")
start = t()

for epoch in range(n_epochs):
labels = []

if epoch % progress_interval == 0:
print("\n Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
print("\nProgress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
start = t()

# Create a dataloader to iterate and batch data
Expand All@@ -183,13 +184,10 @@

pbar_training = tqdm(total=n_train)
for step, batch in enumerate(train_dataloader):
if step > n_train:
if step * batch_size > n_train:
break
# Get next input sample.
inputs ={"X": batch["encoded_image"]}
if gpu:
inputs ={k: v.cuda() for k, v in inputs.items()}

# Assign labels to excitatory neurons.
if step % update_steps == 0 and step > 0:
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels, device=device)
Expand DownExpand Up@@ -245,6 +243,12 @@

labels = []

# Get next input sample.
inputs ={"X": batch["encoded_image"]}
if gpu:
inputs ={k: v.cuda() for k, v in inputs.items()}

# Remember labels.
labels.extend(batch["label"].tolist())

# Run the network on the input.
Expand DownExpand Up@@ -293,9 +297,10 @@

network.reset_state_variables() # Reset state variables.
pbar_training.update(batch_size)
pbar_training.close()

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
print("\nTraining complete.\n")

# Load MNIST data.
test_dataset = MNIST(
Expand All@@ -322,13 +327,15 @@
accuracy ={"all": 0, "proportion": 0}

# Train the network.
print("\nBegin testing\n")
print("\nBegin testing...\n")
network.train(mode=False)
start = t()

pbar = tqdm(total=n_test)
for step, batch in enumerate(test_dataset):
if step > n_test:
pbar.set_description_str("Test progress: ")

for step, batch in enumerate(test_dataloader):
if step * batch_size > n_test:
break
# Get next input sample.
inputs ={"X": batch["encoded_image"]}
Expand DownExpand Up@@ -362,11 +369,11 @@
)

network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Test progress: ")
pbar.update()
pbar.update(batch_size)
pbar.close()

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
print("\nTesting complete.\n")