diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index d57f8cec..b29113f4 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -4,14 +4,17 @@ import numpy as np import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair from bindsnet.utils import im2col_indices - from ..network.nodes import SRM0Nodes from ..network.topology import ( AbstractConnection, Connection, + Conv1dConnection, Conv2dConnection, + Conv3dConnection, LocalConnection, ) @@ -173,8 +176,12 @@ def __init__( if isinstance(connection, (Connection, LocalConnection)): self.update = self._connection_update + elif isinstance(connection, Conv1dConnection): + self.update = self._conv1d_connection_update elif isinstance(connection, Conv2dConnection): self.update = self._conv2d_connection_update + elif isinstance(connection, Conv3dConnection): + self.update = self._conv3d_connection_update else: raise NotImplementedError( "This learning rule is not supported for this Connection type." @@ -206,6 +213,41 @@ def _connection_update(self, **kwargs) -> None: super().update() + def _conv1d_connection_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``Conv1dConnection`` subclass of + ``AbstractConnection`` class. + """ + # Get convolutional layer parameters. + out_channels, in_channels, kernel_size = self.connection.w.size() + padding, stride = self.connection.padding, self.connection.stride + batch_size = self.source.batch_size + + # Reshaping spike traces and spike occurrences. + source_x = F.pad(self.source.x, _pair(padding)) + source_x = source_x.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + target_x = self.target.x.view(batch_size, out_channels, -1) + source_s = F.pad(self.source.s.float(), _pair(padding)) + source_s = source_s.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + target_s = self.target.s.view(batch_size, out_channels, -1).float() + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) + self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) + + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x), dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + def _conv2d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -247,6 +289,67 @@ def _conv2d_connection_update(self, **kwargs) -> None: super().update() + def _conv3d_connection_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``Conv3dConnection`` subclass of + ``AbstractConnection`` class. + """ + # Get convolutional layer parameters. + ( + out_channels, + in_channels, + kernel_depth, + kernel_height, + kernel_width, + ) = self.connection.w.size() + padding, stride = self.connection.padding, self.connection.stride + batch_size = self.source.batch_size + + # Reshaping spike traces and spike occurrences. 
+        source_x = F.pad(
+            self.source.x,
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_x = (
+            source_x.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_x = self.target.x.view(batch_size, out_channels, -1)
+        source_s = F.pad(
+            self.source.s.float(),
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_s = (
+            source_s.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        # Pre-synaptic update.
+        if self.nu[0]:
+            pre = self.reduction(torch.bmm(target_x, source_s), dim=0)
+            self.connection.w -= self.nu[0] * pre.view(self.connection.w.size())
+
+        # Post-synaptic update.
+        if self.nu[1]:
+            post = self.reduction(torch.bmm(target_s, source_x), dim=0)
+            self.connection.w += self.nu[1] * post.view(self.connection.w.size())
+
+        super().update()
+

 class WeightDependentPostPre(LearningRule):
     # language=rst
@@ -293,8 +396,12 @@ def __init__(
         if isinstance(connection, (Connection, LocalConnection)):
             self.update = self._connection_update
+        elif isinstance(connection, Conv1dConnection):
+            self.update = self._conv1d_connection_update
         elif isinstance(connection, Conv2dConnection):
             self.update = self._conv2d_connection_update
+        elif isinstance(connection, Conv3dConnection):
+            self.update = self._conv3d_connection_update
         else:
             raise NotImplementedError(
                 "This learning rule is not supported for this Connection type."
             )
@@ -329,6 +436,57 @@ def _connection_update(self, **kwargs) -> None:

         super().update()

+    def _conv1d_connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Post-pre learning rule for ``Conv1dConnection`` subclass of
+        ``AbstractConnection`` class with weight dependence.
+        """
+        # Get convolutional layer parameters.
+        (
+            out_channels,
+            in_channels,
+            kernel_size,
+        ) = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+        batch_size = self.source.batch_size
+
+        # Reshaping spike traces and spike occurrences.
+        source_x = F.pad(self.source.x, _pair(padding))
+        source_x = source_x.unfold(-1, kernel_size, stride).reshape(
+            batch_size, -1, in_channels * kernel_size
+        )
+        target_x = self.target.x.view(batch_size, out_channels, -1)
+        source_s = F.pad(self.source.s.float(), _pair(padding))
+        source_s = source_s.unfold(-1, kernel_size, stride).reshape(
+            batch_size, -1, in_channels * kernel_size
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        update = 0
+
+        # Pre-synaptic update.
+        if self.nu[0]:
+            pre = self.reduction(torch.bmm(target_x, source_s), dim=0)
+            update -= (
+                self.nu[0]
+                * pre.view(self.connection.w.size())
+                * (self.connection.w - self.wmin)
+            )
+
+        # Post-synaptic update.
+        if self.nu[1]:
+            post = self.reduction(torch.bmm(target_s, source_x), dim=0)
+            update += (
+                self.nu[1]
+                * post.view(self.connection.w.size())
+                * (self.wmax - self.connection.w)
+            )
+
+        self.connection.w += update
+
+        super().update()
+
     def _conv2d_connection_update(self, **kwargs) -> None:
         # language=rst
         """
@@ -387,6 +545,79 @@ def _conv2d_connection_update(self, **kwargs) -> None:

         super().update()

+    def _conv3d_connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Post-pre learning rule for ``Conv3dConnection`` subclass of
+        ``AbstractConnection`` class with weight dependence.
+        """
+        # Get convolutional layer parameters.
+        (
+            out_channels,
+            in_channels,
+            kernel_depth,
+            kernel_height,
+            kernel_width,
+        ) = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+        batch_size = self.source.batch_size
+
+        # Reshaping spike traces and spike occurrences.
+        source_x = F.pad(
+            self.source.x,
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_x = (
+            source_x.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_x = self.target.x.view(batch_size, out_channels, -1)
+        source_s = F.pad(
+            self.source.s.float(),
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_s = (
+            source_s.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        update = 0
+
+        # Pre-synaptic update.
+        if self.nu[0]:
+            pre = self.reduction(torch.bmm(target_x, source_s), dim=0)
+            update -= (
+                self.nu[0]
+                * pre.view(self.connection.w.size())
+                * (self.connection.w - self.wmin)
+            )
+
+        # Post-synaptic update.
+        if self.nu[1]:
+            post = self.reduction(torch.bmm(target_s, source_x), dim=0)
+            update += (
+                self.nu[1]
+                * post.view(self.connection.w.size())
+                * (self.wmax - self.connection.w)
+            )
+
+        self.connection.w += update
+
+        super().update()
+

 class Hebbian(LearningRule):
     # language=rst
@@ -427,8 +658,12 @@ def __init__(
         if isinstance(connection, (Connection, LocalConnection)):
             self.update = self._connection_update
+        elif isinstance(connection, Conv1dConnection):
+            self.update = self._conv1d_connection_update
         elif isinstance(connection, Conv2dConnection):
             self.update = self._conv2d_connection_update
+        elif isinstance(connection, Conv3dConnection):
+            self.update = self._conv3d_connection_update
         else:
             raise NotImplementedError(
                 "This learning rule is not supported for this Connection type."
             )
@@ -457,6 +692,38 @@ def _connection_update(self, **kwargs) -> None:

         super().update()

+    def _conv1d_connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Hebbian learning rule for ``Conv1dConnection`` subclass of
+        ``AbstractConnection`` class.
+        """
+        out_channels, in_channels, kernel_size = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+        batch_size = self.source.batch_size
+
+        # Reshaping spike traces and spike occurrences.
+        source_x = F.pad(self.source.x, _pair(padding))
+        source_x = source_x.unfold(-1, kernel_size, stride).reshape(
+            batch_size, -1, in_channels * kernel_size
+        )
+        target_x = self.target.x.view(batch_size, out_channels, -1)
+        source_s = F.pad(self.source.s.float(), _pair(padding))
+        source_s = source_s.unfold(-1, kernel_size, stride).reshape(
+            batch_size, -1, in_channels * kernel_size
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        # Pre-synaptic update.
+        pre = self.reduction(torch.bmm(target_x, source_s), dim=0)
+        self.connection.w += self.nu[0] * pre.view(self.connection.w.size())
+
+        # Post-synaptic update.
+        post = self.reduction(torch.bmm(target_s, source_x), dim=0)
+        self.connection.w += self.nu[1] * post.view(self.connection.w.size())
+
+        super().update()
+
     def _conv2d_connection_update(self, **kwargs) -> None:
         # language=rst
         """
@@ -491,6 +758,64 @@ def _conv2d_connection_update(self, **kwargs) -> None:

         super().update()

+    def _conv3d_connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Hebbian learning rule for ``Conv3dConnection`` subclass of
+        ``AbstractConnection`` class.
+        """
+        (
+            out_channels,
+            in_channels,
+            kernel_depth,
+            kernel_height,
+            kernel_width,
+        ) = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+        batch_size = self.source.batch_size
+
+        # Reshaping spike traces and spike occurrences.
+        source_x = F.pad(
+            self.source.x,
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_x = (
+            source_x.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_x = self.target.x.view(batch_size, out_channels, -1)
+        source_s = F.pad(
+            self.source.s.float(),
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_s = (
+            source_s.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        # Pre-synaptic update.
+        pre = self.reduction(torch.bmm(target_x, source_s), dim=0)
+        self.connection.w += self.nu[0] * pre.view(self.connection.w.size())
+
+        # Post-synaptic update.
+        post = self.reduction(torch.bmm(target_s, source_x), dim=0)
+        self.connection.w += self.nu[1] * post.view(self.connection.w.size())
+
+        super().update()
+

 class MSTDP(LearningRule):
     # language=rst
@@ -534,8 +859,12 @@ def __init__(
         if isinstance(connection, (Connection, LocalConnection)):
             self.update = self._connection_update
+        elif isinstance(connection, Conv1dConnection):
+            self.update = self._conv1d_connection_update
         elif isinstance(connection, Conv2dConnection):
             self.update = self._conv2d_connection_update
+        elif isinstance(connection, Conv3dConnection):
+            self.update = self._conv3d_connection_update
         else:
             raise NotImplementedError(
                 "This learning rule is not supported for this Connection type."
@@ -608,10 +937,10 @@ def _connection_update(self, **kwargs) -> None: super().update() - def _conv2d_connection_update(self, **kwargs) -> None: + def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ - MSTDP learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` + MSTDP learning rule for ``Conv1dConnection`` subclass of ``AbstractConnection`` class. Keyword arguments: @@ -638,8 +967,80 @@ def _conv2d_connection_update(self, **kwargs) -> None: kwargs.get("a_minus", -1.0), device=self.connection.w.device ) + # Compute weight update based on the eligibility value of the past timestep. + update = reward * self.eligibility + self.connection.w += self.nu[0] * torch.sum(update, dim=0) + + out_channels, in_channels, kernel_size = self.connection.w.size() + padding, stride = self.connection.padding, self.connection.stride + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = F.pad(self.p_plus, _pair(padding)) + self.p_plus = self.p_plus.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float() + + # Reshaping spike occurrences. + source_s = F.pad(self.source.s.float(), _pair(padding)) + source_s = source_s.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + target_s = self.target.s.view(batch_size, out_channels, -1).float() + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) + self.eligibility = self.eligibility.view(self.connection.w.size()) + + super().update() + + def _conv2d_connection_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` + class. + + Keyword arguments: + + :param Union[float, torch.Tensor] reward: Reward signal from reinforcement + learning task. + :param float a_plus: Learning rate (post-synaptic). + :param float a_minus: Learning rate (pre-synaptic). + """ batch_size = self.source.batch_size + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + # Compute weight update based on the eligibility value of the past timestep. update = reward * self.eligibility self.connection.w += self.nu[0] * torch.sum(update, dim=0) @@ -685,6 +1086,112 @@ def _conv2d_connection_update(self, **kwargs) -> None: super().update() + def _conv3d_connection_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``Conv3dConnection`` subclass of ``AbstractConnection`` + class. 
+
+        Keyword arguments:
+
+        :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
+            learning task.
+        :param float a_plus: Learning rate (post-synaptic).
+        :param float a_minus: Learning rate (pre-synaptic).
+        """
+        batch_size = self.source.batch_size
+
+        # Initialize eligibility.
+        if not hasattr(self, "eligibility"):
+            self.eligibility = torch.zeros(
+                batch_size, *self.connection.w.shape, device=self.connection.w.device
+            )
+
+        # Parse keyword arguments.
+        reward = kwargs["reward"]
+        a_plus = torch.tensor(
+            kwargs.get("a_plus", 1.0), device=self.connection.w.device
+        )
+        a_minus = torch.tensor(
+            kwargs.get("a_minus", -1.0), device=self.connection.w.device
+        )
+
+        # Compute weight update based on the eligibility value of the past timestep.
+        update = reward * self.eligibility
+        self.connection.w += self.nu[0] * torch.sum(update, dim=0)
+
+        (
+            out_channels,
+            in_channels,
+            kernel_depth,
+            kernel_height,
+            kernel_width,
+        ) = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+
+        # Initialize P^+ and P^-.
+        if not hasattr(self, "p_plus"):
+            self.p_plus = torch.zeros(
+                batch_size, *self.source.shape, device=self.connection.w.device
+            )
+            self.p_plus = F.pad(
+                self.p_plus,
+                (
+                    padding[0],
+                    padding[0],
+                    padding[1],
+                    padding[1],
+                    padding[2],
+                    padding[2],
+                ),
+            )
+            self.p_plus = (
+                self.p_plus.unfold(-3, kernel_width, stride[0])
+                .unfold(-3, kernel_height, stride[1])
+                .unfold(-3, kernel_depth, stride[2])
+                .reshape(
+                    batch_size,
+                    -1,
+                    in_channels * kernel_width * kernel_height * kernel_depth,
+                )
+            )
+        if not hasattr(self, "p_minus"):
+            self.p_minus = torch.zeros(
+                batch_size, *self.target.shape, device=self.connection.w.device
+            )
+            self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()
+
+        # Reshaping spike occurrences.
+        source_s = F.pad(
+            self.source.s.float(),
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_s = (
+            source_s.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_s = self.target.s.view(batch_size, out_channels, -1).float()
+
+        # Update P^+ and P^- values.
+        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
+        self.p_plus += a_plus * source_s
+        self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
+        self.p_minus += a_minus * target_s
+
+        # Calculate point eligibility value.
+        self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm(
+            self.p_minus, source_s
+        )
+        self.eligibility = self.eligibility.view(self.connection.w.size())
+
+        super().update()
+

 class MSTDPET(LearningRule):
     # language=rst
@@ -729,8 +1236,12 @@ def __init__(
         if isinstance(connection, (Connection, LocalConnection)):
             self.update = self._connection_update
+        elif isinstance(connection, Conv1dConnection):
+            self.update = self._conv1d_connection_update
         elif isinstance(connection, Conv2dConnection):
             self.update = self._conv2d_connection_update
+        elif isinstance(connection, Conv3dConnection):
+            self.update = self._conv3d_connection_update
         else:
             raise NotImplementedError(
                 "This learning rule is not supported for this Connection type."
             )
@@ -803,6 +1314,89 @@ def _connection_update(self, **kwargs) -> None:

         super().update()

+    def _conv1d_connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        MSTDPET learning rule for ``Conv1dConnection`` subclass of
+        ``AbstractConnection`` class.
+ + Keyword arguments: + + :param Union[float, torch.Tensor] reward: Reward signal from reinforcement + learning task. + :param float a_plus: Learning rate (post-synaptic). + :param float a_minus: Learning rate (pre-synaptic). + """ + batch_size = self.source.batch_size + + # Initialize eligibility and eligibility trace. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + if not hasattr(self, "eligibility_trace"): + self.eligibility_trace = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Calculate value of eligibility trace based on the value + # of the point eligibility value of the past timestep. + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) + + # Compute weight update. + update = reward * self.eligibility_trace + self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0) + + out_channels, in_channels, kernel_size = self.connection.w.size() + padding, stride = self.connection.padding, self.connection.stride + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = F.pad(self.p_plus.float(), _pair(padding)) + self.p_plus = self.p_plus.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float() + + # Reshaping spike occurrences. + source_s = F.pad(self.source.s.float(), _pair(padding)) + source_s = source_s.unfold(-1, kernel_size, stride).reshape( + batch_size, -1, in_channels * kernel_size + ) + target_s = ( + self.target.s.permute(1, 2, 0).view(batch_size, out_channels, -1).float() + ) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) + self.eligibility = self.eligibility.view(self.connection.w.size()) + + super().update() + def _conv2d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -888,6 +1482,124 @@ def _conv2d_connection_update(self, **kwargs) -> None: super().update() + def _conv3d_connection_update(self, **kwargs) -> None: + # language=rst + """ + MSTDPET learning rule for ``Conv3dConnection`` subclass of + ``AbstractConnection`` class. + + Keyword arguments: + + :param Union[float, torch.Tensor] reward: Reward signal from reinforcement + learning task. + :param float a_plus: Learning rate (post-synaptic). + :param float a_minus: Learning rate (pre-synaptic). + """ + batch_size = self.source.batch_size + + # Initialize eligibility and eligibility trace. 
+        if not hasattr(self, "eligibility"):
+            self.eligibility = torch.zeros(
+                batch_size, *self.connection.w.shape, device=self.connection.w.device
+            )
+        if not hasattr(self, "eligibility_trace"):
+            self.eligibility_trace = torch.zeros(
+                batch_size, *self.connection.w.shape, device=self.connection.w.device
+            )
+
+        # Parse keyword arguments.
+        reward = kwargs["reward"]
+        a_plus = torch.tensor(
+            kwargs.get("a_plus", 1.0), device=self.connection.w.device
+        )
+        a_minus = torch.tensor(
+            kwargs.get("a_minus", -1.0), device=self.connection.w.device
+        )
+
+        # Calculate value of eligibility trace based on the value
+        # of the point eligibility value of the past timestep.
+        self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
+
+        # Compute weight update.
+        update = reward * self.eligibility_trace
+        self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0)
+
+        (
+            out_channels,
+            in_channels,
+            kernel_depth,
+            kernel_height,
+            kernel_width,
+        ) = self.connection.w.size()
+        padding, stride = self.connection.padding, self.connection.stride
+
+        # Initialize P^+ and P^-.
+        if not hasattr(self, "p_plus"):
+            self.p_plus = torch.zeros(
+                batch_size, *self.source.shape, device=self.connection.w.device
+            )
+            self.p_plus = F.pad(
+                self.p_plus,
+                (
+                    padding[0],
+                    padding[0],
+                    padding[1],
+                    padding[1],
+                    padding[2],
+                    padding[2],
+                ),
+            )
+            self.p_plus = (
+                self.p_plus.unfold(-3, kernel_width, stride[0])
+                .unfold(-3, kernel_height, stride[1])
+                .unfold(-3, kernel_depth, stride[2])
+                .reshape(
+                    batch_size,
+                    -1,
+                    in_channels * kernel_width * kernel_height * kernel_depth,
+                )
+            )
+        if not hasattr(self, "p_minus"):
+            self.p_minus = torch.zeros(
+                batch_size, *self.target.shape, device=self.connection.w.device
+            )
+            self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()
+
+        # Reshaping spike occurrences.
+        source_s = F.pad(
+            self.source.s.float(),
+            (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2]),
+        )
+        source_s = (
+            source_s.unfold(-3, kernel_width, stride[0])
+            .unfold(-3, kernel_height, stride[1])
+            .unfold(-3, kernel_depth, stride[2])
+            .reshape(
+                batch_size,
+                -1,
+                in_channels * kernel_width * kernel_height * kernel_depth,
+            )
+        )
+        target_s = (
+            self.target.s.permute(1, 2, 3, 4, 0)
+            .view(batch_size, out_channels, -1)
+            .float()
+        )
+
+        # Update P^+ and P^- values.
+        self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
+        self.p_plus += a_plus * source_s
+        self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
+        self.p_minus += a_minus * target_s
+
+        # Calculate point eligibility value.
+        self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm(
+            self.p_minus, source_s
+        )
+        self.eligibility = self.eligibility.view(self.connection.w.size())
+
+        super().update()
+

 class Rmax(LearningRule):
     # language=rst
diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py
index 3b166f18..d0677f7c 100644
--- a/bindsnet/network/topology.py
+++ b/bindsnet/network/topology.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Module, Parameter
-from torch.nn.modules.utils import _pair
+from torch.nn.modules.utils import _pair, _triple

 from bindsnet.network.nodes import CSRMNodes, Nodes

@@ -253,10 +253,156 @@ def reset_state_variables(self) -> None:
         super().reset_state_variables()


+class Conv1dConnection(AbstractConnection):
+    # language=rst
+    """
+    Specifies one-dimensional convolutional synapses between one or two populations of neurons.
+ """ + + def __init__( + self, + source: Nodes, + target: Nodes, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + weight_decay: float = 0.0, + **kwargs, + ) -> None: + # language=rst + """ + Instantiates a ``Conv1dConnection`` object. + + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param kernel_size: the size of 1-D convolutional kernel. + :param stride: stride for convolution. + :param padding: padding for convolution. + :param dilation: dilation for convolution. + :param nu: Learning rate for both pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the minibatch + dimension. + :param weight_decay: Constant multiple to decay weights by on each iteration. + + Keyword arguments: + + :param LearningRule update_rule: Modifies connection parameters according to + some rule. + :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor b: Target population bias. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param float norm: Total weight per target neuron normalization constant. + """ + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) + + if dilation != 1: + raise NotImplementedError( + "Dilation is not currently supported for 1-D spiking convolution." + ) + + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + + self.in_channels, input_size = ( + source.shape[0], + source.shape[1], + ) + self.out_channels, output_size = ( + target.shape[0], + target.shape[1], + ) + + conv_size = (input_size - self.kernel_size + 2 * self.padding) / self.stride + 1 + shape = (self.in_channels, self.out_channels, int(conv_size)) + + error = ( + "Target dimensionality must be (out_channels, ?," + "(input_size - filter_size + 2 * padding) / stride + 1," + ) + + assert target.shape[0] == shape[1] and target.shape[1] == shape[2], error + + w = kwargs.get("w", None) + inf = torch.tensor(np.inf) + if w is None: + if (self.wmin == -inf).any() or (self.wmax == inf).any(): + w = torch.clamp( + torch.rand(self.out_channels, self.in_channels, self.kernel_size), + self.wmin, + self.wmax, + ) + else: + w = (self.wmax - self.wmin) * torch.rand( + self.out_channels, self.in_channels, self.kernel_size + ) + w += self.wmin + else: + if (self.wmin == -inf).any() or (self.wmax == inf).any(): + w = torch.clamp(w, self.wmin, self.wmax) + + self.w = Parameter(w, requires_grad=False) + self.b = Parameter( + kwargs.get("b", torch.zeros(self.out_channels)), requires_grad=False + ) + + def compute(self, s: torch.Tensor) -> torch.Tensor: + # language=rst + """ + Compute convolutional pre-activations given spikes using layer weights. + + :param s: Incoming spikes. + :return: Incoming spikes multiplied by synaptic weights (with or without + decaying spike activation). + """ + return F.conv1d( + s.float(), + self.w, + self.b, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + + def update(self, **kwargs) -> None: + # language=rst + """ + Compute connection's update rule. 
+ """ + super().update(**kwargs) + + def normalize(self) -> None: + # language=rst + """ + Normalize weights along the first axis according to total weight per target + neuron. + """ + if self.norm is not None: + # get a view and modify in place + w = self.w.view(self.w.shape[0] * self.w.shape[1], self.w.shape[2]) + + for fltr in range(w.shape[0]): + w[fltr] *= self.norm / w[fltr].sum(0) + + def reset_state_variables(self) -> None: + # language=rst + """ + Contains resetting logic for the connection. + """ + super().reset_state_variables() + + class Conv2dConnection(AbstractConnection): # language=rst """ - Specifies convolutional synapses between one or two populations of neurons. + Specifies two-dimensional convolutional synapses between one or two populations of neurons. """ def __init__( @@ -408,6 +554,269 @@ def reset_state_variables(self) -> None: super().reset_state_variables() +class Conv3dConnection(AbstractConnection): + # language=rst + """ + Specifies three-dimensional convolutional synapses between one or two populations of neurons. + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + weight_decay: float = 0.0, + **kwargs, + ) -> None: + # language=rst + """ + Instantiates a ``Conv3dConnection`` object. + + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param kernel_size: Depth-wise, horizontal, and vertical size of convolutional kernels. + :param stride: Depth-wise, horizontal, and vertical stride for convolution. + :param padding: Depth-wise, horizontal, and vertical padding for convolution. + :param dilation: Depth-wise, horizontal and vertical dilation for convolution. + :param nu: Learning rate for both pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the minibatch + dimension. + :param weight_decay: Constant multiple to decay weights by on each iteration. + + Keyword arguments: + + :param LearningRule update_rule: Modifies connection parameters according to + some rule. + :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor b: Target population bias. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param float norm: Total weight per target neuron normalization constant. + """ + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) + + if dilation != 1 and dilation != (1, 1, 1): + raise NotImplementedError( + "Dilation is not currently supported for 3-D spiking convolution." 
+            )
+
+        self.kernel_size = _triple(kernel_size)
+        self.stride = _triple(stride)
+        self.padding = _triple(padding)
+        self.dilation = _triple(dilation)
+
+        self.in_channels, input_depth, input_height, input_width = (
+            source.shape[0],
+            source.shape[1],
+            source.shape[2],
+            source.shape[3],
+        )
+        self.out_channels, output_depth, output_height, output_width = (
+            target.shape[0],
+            target.shape[1],
+            target.shape[2],
+            target.shape[3],
+        )
+
+        depth = (input_depth - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[
+            0
+        ] + 1
+        height = (
+            input_height - self.kernel_size[1] + 2 * self.padding[1]
+        ) / self.stride[1] + 1
+        width = (
+            input_width - self.kernel_size[2] + 2 * self.padding[2]
+        ) / self.stride[2] + 1
+
+        shape = (
+            self.in_channels,
+            self.out_channels,
+            int(depth),
+            int(height),
+            int(width),
+        )
+
+        error = (
+            "Target dimensionality must be (out_channels, "
+            "(input_depth - filter_depth + 2 * padding_depth) / stride_depth + 1, "
+            "(input_height - filter_height + 2 * padding_height) / stride_height + 1, "
+            "(input_width - filter_width + 2 * padding_width) / stride_width + 1)"
+        )
+
+        assert (
+            target.shape[0] == shape[1]
+            and target.shape[1] == shape[2]
+            and target.shape[2] == shape[3]
+            and target.shape[3] == shape[4]
+        ), error
+
+        w = kwargs.get("w", None)
+        inf = torch.tensor(np.inf)
+        if w is None:
+            if (self.wmin == -inf).any() or (self.wmax == inf).any():
+                w = torch.clamp(
+                    torch.rand(self.out_channels, self.in_channels, *self.kernel_size),
+                    self.wmin,
+                    self.wmax,
+                )
+            else:
+                w = (self.wmax - self.wmin) * torch.rand(
+                    self.out_channels, self.in_channels, *self.kernel_size
+                )
+                w += self.wmin
+        else:
+            if (self.wmin == -inf).any() or (self.wmax == inf).any():
+                w = torch.clamp(w, self.wmin, self.wmax)
+
+        self.w = Parameter(w, requires_grad=False)
+        self.b = Parameter(
+            kwargs.get("b", torch.zeros(self.out_channels)), requires_grad=False
+        )
+
+    def compute(self, s: torch.Tensor) -> torch.Tensor:
+        # language=rst
+        """
+        Compute convolutional pre-activations given spikes using layer weights.
+
+        :param s: Incoming spikes.
+        :return: Incoming spikes multiplied by synaptic weights (with or without
+            decaying spike activation).
+        """
+        return F.conv3d(
+            s.float(),
+            self.w,
+            self.b,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+        )
+
+    def update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Compute connection's update rule.
+        """
+        super().update(**kwargs)
+
+    def normalize(self) -> None:
+        # language=rst
+        """
+        Normalize weights along the first axis according to total weight per target
+        neuron.
+        """
+        if self.norm is not None:
+            # get a view and modify in place
+            w = self.w.view(
+                self.w.shape[0] * self.w.shape[1],
+                self.w.shape[2] * self.w.shape[3] * self.w.shape[4],
+            )
+
+            for fltr in range(w.shape[0]):
+                w[fltr] *= self.norm / w[fltr].sum(0)
+
+    def reset_state_variables(self) -> None:
+        # language=rst
+        """
+        Contains resetting logic for the connection.
+        """
+        super().reset_state_variables()
+
+
+class MaxPool1dConnection(AbstractConnection):
+    # language=rst
+    """
+    Specifies max-pooling synapses between one or two populations of neurons by keeping
+    online estimates of maximally firing neurons.
+    """
+
+    def __init__(
+        self,
+        source: Nodes,
+        target: Nodes,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        **kwargs,
+    ) -> None:
+        # language=rst
+        """
+        Instantiates a ``MaxPool1dConnection`` object.
+
+        :param source: A layer of nodes from which the connection originates.
+        :param target: A layer of nodes to which the connection connects.
+        :param kernel_size: Size of the 1-D max-pooling kernel.
+        :param stride: Stride for pooling.
+        :param padding: Padding for pooling.
+        :param dilation: Dilation for pooling.
+
+        Keyword arguments:
+
+        :param decay: Decay rate of online estimates of average firing activity.
+        """
+        super().__init__(source, target, None, None, 0.0, **kwargs)
+
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+
+        self.register_buffer("firing_rates", torch.zeros(source.s.shape))
+
+    def compute(self, s: torch.Tensor) -> torch.Tensor:
+        # language=rst
+        """
+        Compute max-pool pre-activations given spikes using online firing rate
+        estimates.
+
+        :param s: Incoming spikes.
+        :return: Max-pooled spikes, selected according to online estimates of
+            maximally firing neurons.
+        """
+        self.firing_rates -= self.decay * self.firing_rates
+        self.firing_rates += s.float().squeeze()
+
+        _, indices = F.max_pool1d(
+            self.firing_rates,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            return_indices=True,
+        )
+
+        return s.flatten(2).gather(2, indices.flatten(2)).view_as(indices).float()
+
+    def update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Compute connection's update rule.
+        """
+        super().update(**kwargs)
+
+    def normalize(self) -> None:
+        # language=rst
+        """
+        No weights -> no normalization.
+        """
+
+    def reset_state_variables(self) -> None:
+        # language=rst
+        """
+        Contains resetting logic for the connection.
+        """
+        super().reset_state_variables()
+
+        self.firing_rates = torch.zeros(self.source.s.shape)
+
+
 class MaxPool2dConnection(AbstractConnection):
     # language=rst
     """
@@ -496,6 +905,94 @@ def reset_state_variables(self) -> None:

         self.firing_rates = torch.zeros(self.source.s.shape)


+class MaxPool3dConnection(AbstractConnection):
+    # language=rst
+    """
+    Specifies max-pooling synapses between one or two populations of neurons by keeping
+    online estimates of maximally firing neurons.
+    """
+
+    def __init__(
+        self,
+        source: Nodes,
+        target: Nodes,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
+        **kwargs,
+    ) -> None:
+        # language=rst
+        """
+        Instantiates a ``MaxPool3dConnection`` object.
+
+        :param source: A layer of nodes from which the connection originates.
+        :param target: A layer of nodes to which the connection connects.
+        :param kernel_size: Depth-wise, horizontal and vertical size of the pooling kernels.
+        :param stride: Depth-wise, horizontal and vertical stride for pooling.
+        :param padding: Depth-wise, horizontal and vertical padding for pooling.
+        :param dilation: Depth-wise, horizontal and vertical dilation for pooling.
+
+        Keyword arguments:
+
+        :param decay: Decay rate of online estimates of average firing activity.
+        """
+        super().__init__(source, target, None, None, 0.0, **kwargs)
+
+        self.kernel_size = _triple(kernel_size)
+        self.stride = _triple(stride)
+        self.padding = _triple(padding)
+        self.dilation = _triple(dilation)
+
+        self.register_buffer("firing_rates", torch.zeros(source.s.shape))
+
+    def compute(self, s: torch.Tensor) -> torch.Tensor:
+        # language=rst
+        """
+        Compute max-pool pre-activations given spikes using online firing rate
+        estimates.
+
+        :param s: Incoming spikes.
+        :return: Max-pooled spikes, selected according to online estimates of
+            maximally firing neurons.
+        """
+        self.firing_rates -= self.decay * self.firing_rates
+        self.firing_rates += s.float().squeeze()
+
+        _, indices = F.max_pool3d(
+            self.firing_rates,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            return_indices=True,
+        )
+
+        return s.flatten(2).gather(2, indices.flatten(2)).view_as(indices).float()
+
+    def update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Compute connection's update rule.
+        """
+        super().update(**kwargs)
+
+    def normalize(self) -> None:
+        # language=rst
+        """
+        No weights -> no normalization.
+        """
+
+    def reset_state_variables(self) -> None:
+        # language=rst
+        """
+        Contains resetting logic for the connection.
+        """
+        super().reset_state_variables()
+
+        self.firing_rates = torch.zeros(self.source.s.shape)
+
+
 class LocalConnection(AbstractConnection):
     # language=rst
     """
diff --git a/examples/mnist/conv1d_MNIST.py b/examples/mnist/conv1d_MNIST.py
new file mode 100644
index 00000000..db5a7fe6
--- /dev/null
+++ b/examples/mnist/conv1d_MNIST.py
@@ -0,0 +1,184 @@
+### Toy example to test Conv1dConnection (the dataset used is MNIST, but each image is raveled so
+### each sample has shape (784,)).
+
+import argparse
+import os
+from time import time as t
+
+import torch
+from torchvision import transforms
+from tqdm import tqdm
+
+from bindsnet.datasets import MNIST
+from bindsnet.encoding import PoissonEncoder
+from bindsnet.learning import PostPre
+from bindsnet.network import Network
+from bindsnet.network.monitors import Monitor
+from bindsnet.network.nodes import DiehlAndCookNodes, Input
+from bindsnet.network.topology import Connection, Conv1dConnection
+
+print()
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--n_epochs", type=int, default=1)
+parser.add_argument("--n_test", type=int, default=10000)
+parser.add_argument("--n_train", type=int, default=60000)
+parser.add_argument("--batch_size", type=int, default=1)
+parser.add_argument("--kernel_size", type=int, default=28 * 2)
+parser.add_argument("--stride", type=int, default=28)
+parser.add_argument("--n_filters", type=int, default=25)
+parser.add_argument("--padding", type=int, default=0)
+parser.add_argument("--time", type=int, default=50)
+parser.add_argument("--dt", type=float, default=1.0)
+parser.add_argument("--intensity", type=float, default=128.0)
+parser.add_argument("--progress_interval", type=int, default=10)
+parser.add_argument("--update_interval", type=int, default=250)
+parser.add_argument("--train", dest="train", action="store_true")
+parser.add_argument("--test", dest="train", action="store_false")
+parser.add_argument("--plot", dest="plot", action="store_true")
+parser.add_argument("--gpu", dest="gpu", action="store_true")
+parser.set_defaults(plot=True, gpu=True, train=True)
+
+args = parser.parse_args()
+
+seed = args.seed
+n_epochs = args.n_epochs
+n_test = args.n_test
+n_train = args.n_train
+batch_size = args.batch_size
+kernel_size = args.kernel_size
+stride = args.stride
+n_filters = args.n_filters
+padding = args.padding
+time = args.time
+dt = args.dt
+intensity = args.intensity
+progress_interval = args.progress_interval
+update_interval = args.update_interval
+train = args.train
+plot = args.plot
+gpu = args.gpu
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if gpu and torch.cuda.is_available():
+    torch.cuda.manual_seed_all(seed)
+else:
+    torch.manual_seed(seed)
+    device = "cpu"
+    if gpu:
+        gpu = False
+
+torch.set_num_threads(os.cpu_count() - 1)
+print("Running on Device = ", device)
+
+if not train:
+    update_interval = n_test
+
+conv_size = int((28 * 28 - kernel_size + 2 * padding) / stride) + 1
+per_class = int((n_filters * conv_size) / 10)
+
+# Build network.
+network = Network()
+input_layer = Input(n=28 * 28, shape=(1, 28 * 28), traces=True)
+
+conv_layer = DiehlAndCookNodes(
+    n=n_filters * conv_size,
+    shape=(n_filters, conv_size),
+    traces=True,
+)
+
+conv_conn = Conv1dConnection(
+    input_layer,
+    conv_layer,
+    kernel_size=kernel_size,
+    stride=stride,
+    update_rule=PostPre,
+    norm=0.4 * kernel_size,
+    nu=[1e-4, 1e-2],
+    wmax=1.0,
+)
+
+w = torch.zeros(n_filters, conv_size, n_filters, conv_size)
+for fltr1 in range(n_filters):
+    for fltr2 in range(n_filters):
+        if fltr1 != fltr2:
+            for i in range(conv_size):
+                w[fltr1, i, fltr2, i] = -100.0
+
+w = w.view(n_filters * conv_size, n_filters * conv_size)
+recurrent_conn = Connection(conv_layer, conv_layer, w=w)
+
+network.add_layer(input_layer, name="X")
+network.add_layer(conv_layer, name="Y")
+network.add_connection(conv_conn, source="X", target="Y")
+network.add_connection(recurrent_conn, source="Y", target="Y")
+
+# Voltage recording for the output layer.
+voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
+network.add_monitor(voltage_monitor, name="output_voltage")
+
+if gpu:
+    network.to("cuda")
+
+# Load MNIST data.
+train_dataset = MNIST(
+    PoissonEncoder(time=time, dt=dt),
+    None,
+    "../../data/MNIST",
+    download=True,
+    train=True,
+    transform=transforms.Compose(
+        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
+    ),
+)
+
+spikes = {}
+for layer in set(network.layers):
+    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
+    network.add_monitor(spikes[layer], name="%s_spikes" % layer)
+
+voltages = {}
+for layer in set(network.layers) - {"X"}:
+    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
+    network.add_monitor(voltages[layer], name="%s_voltages" % layer)
+
+# Train the network.
+print("Begin training.\n")
+start = t()
+
+inpt_axes = None
+inpt_ims = None
+spike_ims = None
+spike_axes = None
+voltage_ims = None
+voltage_axes = None
+
+for epoch in range(n_epochs):
+    if epoch % progress_interval == 0:
+        print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
+        start = t()
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=gpu,
+    )
+
+    for step, batch in enumerate(tqdm(train_dataloader)):
+        # Get next input sample (raveled to have shape (time, batch_size, 1, 28*28)).
+        if step > n_train:
+            break
+        inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, 28 * 28)}
+        if gpu:
+            inputs = {k: v.cuda() for k, v in inputs.items()}
+        label = batch["label"]
+
+        # Run the network on the input.
+        network.run(inputs=inputs, time=time, input_time_dim=1)
+
+        network.reset_state_variables()  # Reset state variables.
+
+print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start))
+print("Training complete.\n")
diff --git a/examples/mnist/conv3d_MNIST.py b/examples/mnist/conv3d_MNIST.py
new file mode 100644
index 00000000..be0b280e
--- /dev/null
+++ b/examples/mnist/conv3d_MNIST.py
@@ -0,0 +1,207 @@
+### Toy example to test Conv3dConnection (the dataset used is MNIST, but with a dimension replicated
+### for each image so that each sample has shape (28, 28, 28)).
+
+import argparse
+import os
+from time import time as t
+
+import torch
+from torchvision import transforms
+from tqdm import tqdm
+
+from bindsnet.datasets import MNIST
+from bindsnet.encoding import PoissonEncoder
+from bindsnet.learning import PostPre
+from bindsnet.network import Network
+from bindsnet.network.monitors import Monitor
+from bindsnet.network.nodes import DiehlAndCookNodes, Input
+from bindsnet.network.topology import Connection, Conv3dConnection
+
+print()
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--n_epochs", type=int, default=1)
+parser.add_argument("--n_test", type=int, default=10000)
+parser.add_argument("--n_train", type=int, default=60000)
+parser.add_argument("--batch_size", type=int, default=1)
+parser.add_argument("--kernel_size", type=int, default=16)
+parser.add_argument("--stride", type=int, default=4)
+parser.add_argument("--n_filters", type=int, default=25)
+parser.add_argument("--padding", type=int, default=0)
+parser.add_argument("--time", type=int, default=50)
+parser.add_argument("--dt", type=float, default=1.0)
+parser.add_argument("--intensity", type=float, default=128.0)
+parser.add_argument("--progress_interval", type=int, default=10)
+parser.add_argument("--update_interval", type=int, default=250)
+parser.add_argument("--train", dest="train", action="store_true")
+parser.add_argument("--test", dest="train", action="store_false")
+parser.add_argument("--plot", dest="plot", action="store_true")
+parser.add_argument("--gpu", dest="gpu", action="store_true")
+parser.set_defaults(plot=True, gpu=True, train=True)
+
+args = parser.parse_args()
+
+seed = args.seed
+n_epochs = args.n_epochs
+n_test = args.n_test
+n_train = args.n_train
+batch_size = args.batch_size
+kernel_size = args.kernel_size
+stride = args.stride
+n_filters = args.n_filters
+padding = args.padding
+time = args.time
+dt = args.dt
+intensity = args.intensity
+progress_interval = args.progress_interval
+update_interval = args.update_interval
+train = args.train
+plot = args.plot
+gpu = args.gpu
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if gpu and torch.cuda.is_available():
+    torch.cuda.manual_seed_all(seed)
+else:
+    torch.manual_seed(seed)
+    device = "cpu"
+    if gpu:
+        gpu = False
+
+torch.set_num_threads(os.cpu_count() - 1)
+print("Running on Device = ", device)
+
+if not train:
+    update_interval = n_test
+
+conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
+per_class = int((n_filters * conv_size) / 10)
+
+# Build network.
+network = Network()
+input_layer = Input(n=28 * 28 * 28, shape=(1, 28, 28, 28), traces=True)
+
+conv_layer = DiehlAndCookNodes(
+    n=n_filters * conv_size * conv_size * conv_size,
+    shape=(n_filters, conv_size, conv_size, conv_size),
+    traces=True,
+)
+
+conv_conn = Conv3dConnection(
+    input_layer,
+    conv_layer,
+    kernel_size=kernel_size,
+    stride=stride,
+    update_rule=PostPre,
+    norm=0.4 * kernel_size ** 3,
+    nu=[1e-4, 1e-2],
+    wmax=1.0,
+)
+
+w = torch.zeros(
+    n_filters,
+    conv_size,
+    conv_size,
+    conv_size,
+    n_filters,
+    conv_size,
+    conv_size,
+    conv_size,
+)
+for fltr1 in range(n_filters):
+    for fltr2 in range(n_filters):
+        if fltr1 != fltr2:
+            for i in range(conv_size):
+                for j in range(conv_size):
+                    for k in range(conv_size):
+                        w[fltr1, i, j, k, fltr2, i, j, k] = -100.0
+
+w = w.view(
+    n_filters * conv_size * conv_size * conv_size,
+    n_filters * conv_size * conv_size * conv_size,
+)
+recurrent_conn = Connection(conv_layer, conv_layer, w=w)
+
+network.add_layer(input_layer, name="X")
+network.add_layer(conv_layer, name="Y")
+network.add_connection(conv_conn, source="X", target="Y")
+network.add_connection(recurrent_conn, source="Y", target="Y")
+
+# Voltage recording for the output layer.
+voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
+network.add_monitor(voltage_monitor, name="output_voltage")
+
+if gpu:
+    network.to("cuda")
+
+
+# Load MNIST data.
+train_dataset = MNIST(
+    PoissonEncoder(time=time, dt=dt),
+    None,
+    "../../data/MNIST",
+    download=True,
+    train=True,
+    transform=transforms.Compose(
+        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
+    ),
+)
+
+spikes = {}
+for layer in set(network.layers):
+    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
+    network.add_monitor(spikes[layer], name="%s_spikes" % layer)
+
+voltages = {}
+for layer in set(network.layers) - {"X"}:
+    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
+    network.add_monitor(voltages[layer], name="%s_voltages" % layer)
+
+# Train the network.
+print("Begin training.\n")
+start = t()
+
+inpt_axes = None
+inpt_ims = None
+spike_ims = None
+spike_axes = None
+weights1_im = None
+voltage_ims = None
+voltage_axes = None
+
+for epoch in range(n_epochs):
+    if epoch % progress_interval == 0:
+        print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
+        start = t()
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=gpu,
+    )
+
+    for step, batch in enumerate(tqdm(train_dataloader)):
+        # Get next input sample (expanded to have shape (time, batch_size, 1, 28, 28, 28)).
+        if step > n_train:
+            break
+        inputs = {
+            "X": batch["encoded_image"]
+            .view(time, batch_size, 1, 28, 28)
+            .unsqueeze(3)
+            .repeat(1, 1, 1, 28, 1, 1)
+            .float()
+        }
+        if gpu:
+            inputs = {k: v.cuda() for k, v in inputs.items()}
+        label = batch["label"]
+
+        # Run the network on the input.
+        network.run(inputs=inputs, time=time, input_time_dim=1)
+
+        network.reset_state_variables()  # Reset state variables.
+ +print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) +print("Training complete.\n") diff --git a/examples/mnist/conv_mnist.py b/examples/mnist/conv_mnist.py index 99d75cd5..a23f0b4b 100644 --- a/examples/mnist/conv_mnist.py +++ b/examples/mnist/conv_mnist.py @@ -28,6 +28,7 @@ parser.add_argument("--n_epochs", type=int, default=1) parser.add_argument("--n_test", type=int, default=10000) parser.add_argument("--n_train", type=int, default=60000) +parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--kernel_size", type=int, default=16) parser.add_argument("--stride", type=int, default=4) parser.add_argument("--n_filters", type=int, default=25) @@ -49,6 +50,7 @@ n_epochs = args.n_epochs n_test = args.n_test n_train = args.n_train +batch_size = args.batch_size kernel_size = args.kernel_size stride = args.stride n_filters = args.n_filters @@ -164,14 +166,18 @@ start = t() train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=0, + pin_memory=gpu, ) for step, batch in enumerate(tqdm(train_dataloader)): # Get next input sample. if step > n_train: break - inputs = {"X": batch["encoded_image"].view(time, 1, 1, 28, 28)} + inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, 28, 28)} if gpu: inputs = {k: v.cuda() for k, v in inputs.items()} label = batch["label"] @@ -180,7 +186,7 @@ network.run(inputs=inputs, time=time, input_time_dim=1) # Optionally plot various simulation information. - if plot: + if plot and batch_size == 1: image = batch["image"].view(28, 28) inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
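
Usage sketch (reviewer addition, not part of the diff): a minimal end-to-end check of the new Conv1dConnection, assuming this patch is applied. Network, Input, LIFNodes, and the (time, batch, *shape) input convention are existing BindsNET APIs; all sizes below are made up for illustration.

    import torch

    from bindsnet.network import Network
    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import Conv1dConnection

    # Output length follows the formula used in Conv1dConnection:
    # conv_size = (input_size - kernel_size + 2 * padding) / stride + 1.
    time_steps = 25
    in_channels, input_size = 1, 100
    out_channels, kernel_size, stride = 8, 10, 10
    conv_size = (input_size - kernel_size) // stride + 1  # padding = 0 -> 10

    network = Network(dt=1.0)
    source = Input(
        n=in_channels * input_size, shape=(in_channels, input_size), traces=True
    )
    target = LIFNodes(
        n=out_channels * conv_size, shape=(out_channels, conv_size), traces=True
    )
    conn = Conv1dConnection(source, target, kernel_size=kernel_size, stride=stride)

    network.add_layer(source, name="X")
    network.add_layer(target, name="Y")
    network.add_connection(conn, source="X", target="Y")

    # Drive the network with random Bernoulli spikes for `time_steps` steps.
    spikes = (torch.rand(time_steps, 1, in_channels, input_size) < 0.1).byte()
    network.run(inputs={"X": spikes}, time=time_steps)

The same pattern applies to Conv3dConnection with 4-D layer shapes (channels, depth, height, width) and triple-valued kernel_size, stride, and padding.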