22import sys
33from typing import Any , List , Protocol , Union
44
5+ from numba import jit
56import numpy as np
67from ciftools .binary .data_types import DataType , DataTypeEnum
78from ciftools .binary .encoded_data import EncodedCIFData
@@ -105,8 +106,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> EncodedCIFData:
105106class IntegerPacking (BinaryCIFEncoder ):
106107def encode (self , data : np .ndarray ) -> EncodedCIFData :
107108
108- # TODO: must be 32bit integer
109-
109+ # TODO: must be 32bit integer?
110110packing = _determine_packing (data )
111111if packing .bytesPerElement == 4 :
112112return BYTE_ARRAY .encode (data )
@@ -130,8 +130,6 @@ def encode(self, data: np.ndarray) -> EncodedCIFData:
130130
131131lower_limit = - upper_limit - 1
132132
133- # TODO: figure out if there is a way to implement this
134- # better & faster with numpy methods.
135133_pack_values (data , upper_limit , lower_limit , packed )
136134
137135byte_array_result = BYTE_ARRAY .encode (packed )
@@ -154,6 +152,7 @@ class _PackingInfo:
154152bytesPerElement : int
155153
156154
155+ @jit (nopython = True )
157156def _pack_values (data : np .ndarray , upper_limit : int , lower_limit : int , target : np .ndarray ) -> None :
158157offset = 0
159158for value in data :
@@ -200,6 +199,7 @@ def _determine_packing(data: np.ndarray) -> _PackingInfo:
200199return packing
201200
202201
202+ @jit (nopython = True )
203203def _packing_size_signed (data : np .ndarray , upper_limit : int ) -> int :
204204lower_limit = - upper_limit - 1
205205size = 0
@@ -213,6 +213,7 @@ def _packing_size_signed(data: np.ndarray, upper_limit: int) -> int:
213213return size + len (data )
214214
215215
216+ @jit (nopython = True )
216217def _packing_size_unsigned (data : np .ndarray , upper_limit : int ) -> int :
217218size = 0
218219
@@ -299,34 +300,16 @@ def encode(self, data: np.ndarray) -> EncodedCIFData:
299300
300301class StringArray (BinaryCIFEncoder ):
301302def encode (self , data : Union [np .ndarray , list [str ]]) -> EncodedCIFData :
302- _map = dict ()
303-
304303strings : list [str ] = []
305304offsets = [0 ]
306305indices = np .empty (len (data ), dtype = "<i4" )
307306
308- acc_len = 0
309-
310- for i , s in enumerate (data ):
311- # handle null strings.
312- if not s :
313- indices [i ] = - 1
314- continue
315-
316- index = _map .get (s )
317- if index is None :
318- # increment the length
319- acc_len += len (s )
320-
321- # store the string and index
322- index = len (strings )
323- strings .append (s )
324- _map [s ] = index
325-
326- # write the offset
327- offsets .append (acc_len )
328-
329- indices [i ] = index
307+ _pack_strings (
308+ data ,
309+ indices ,
310+ strings ,
311+ offsets ,
312+ )
330313
331314encoded_offsets = _OFFSET_ENCODER .encode (np .array (offsets , dtype = "<i4" ))
332315encoded_data = _DATA_ENCODER .encode (indices )
@@ -341,4 +324,31 @@ def encode(self, data: Union[np.ndarray, list[str]]) -> EncodedCIFData:
341324
342325return EncodedCIFData (data = encoded_data ["data" ], encoding = [encoding ])
343326
# TODO: benchmark if JIT helps here
@jit(nopython=False, forceobj=True)
def _pack_strings(
    data: List[str],
    indices: np.ndarray,
    strings: List[str],
    offsets: List[int],
) -> None:
    """Deduplicate ``data`` in place into ``strings``, ``offsets`` and ``indices``.

    Nothing is returned; all results are written into the arguments.

    Args:
        data: The values to pack. Falsy entries (e.g. ``""`` or ``None``)
            are treated as null.
        indices: Output array, written at every position of ``data``:
            ``-1`` for a null entry, otherwise the position of the value
            inside ``strings``.
        strings: Output list; each distinct non-null value is appended
            once, in order of first appearance.
        offsets: Output list; after each newly appended value the running
            total of the lengths of all values appended during this call
            is pushed onto it. Existing content is left untouched
            (presumably the caller seeds it with ``0`` -- confirm at the
            call site).
    """
    first_seen = {}
    running_length = 0

    for position, text in enumerate(data):
        if text:
            slot = first_seen.get(text)
            if slot is None:
                # First occurrence: grow the running length, register the
                # value, then record where its characters end.
                running_length += len(text)
                slot = len(strings)
                strings.append(text)
                first_seen[text] = slot
                offsets.append(running_length)
            indices[position] = slot
        else:
            # Null / empty value.
            indices[position] = -1
353+
344354STRING_ARRAY = StringArray ()
0 commit comments