diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index fd240de6..5a413912 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -139,7 +139,9 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ self.batch_size = batch_size - self.s = torch.zeros(batch_size, *self.shape, device=self.s.device) + self.s = torch.zeros( + batch_size, *self.shape, device=self.s.device, dtype=torch.bool + ) if self.traces: self.x = torch.zeros(batch_size, *self.shape, device=self.x.device)