Skip to content

Commit 5165254

Browse files
committed
JIT in encoders
1 parent 9719df6 commit 5165254

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

‎ciftools/binary/decoder.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
StringArrayEncoding,
1414
)
1515

16-
1716
def decode_cif_data(encoded_data: EncodedCIFData) -> Union[np.ndarray, list[str]]:
1817
result=encoded_data["data"]
1918
for encoding in encoded_data["encoding"][::-1]:
@@ -48,7 +47,7 @@ def _decode_delta(data: np.ndarray, encoding: DeltaEncoding) -> np.ndarray:
4847
result[0] +=encoding["origin"]
4948
return np.cumsum(result, out=result)
5049

51-
50+
# TODO: JIT
5251
def _decode_integer_packing_signed(data: np.ndarray, encoding: IntegerPackingEncoding) -> np.ndarray:
5352
upper_limit = 0x7F if encoding["byteCount"] == 1 else 0x7FFF
5453
lower_limit=-upper_limit-1
@@ -70,6 +69,7 @@ def _decode_integer_packing_signed(data: np.ndarray, encoding: IntegerPackingEnc
7069
return output
7170

7271

72+
# TODO: JIT
7373
def _decode_integer_packing_unsigned(data: np.ndarray, encoding: IntegerPackingEncoding) -> np.ndarray:
7474
upper_limit = 0xFF if encoding["byteCount"] == 1 else 0xFFFF
7575
n=len(data)

‎ciftools/binary/encoder.py‎

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from typing import Any, List, Protocol, Union
44

5+
from numba import jit
56
import numpy as np
67
from ciftools.binary.data_types import DataType, DataTypeEnum
78
from ciftools.binary.encoded_data import EncodedCIFData
@@ -105,8 +106,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> EncodedCIFData:
105106
class IntegerPacking(BinaryCIFEncoder):
106107
def encode(self, data: np.ndarray) -> EncodedCIFData:
107108

108-
# TODO: must be 32bit integer
109-
109+
# TODO: must be 32bit integer?
110110
packing=_determine_packing(data)
111111
if packing.bytesPerElement == 4:
112112
return BYTE_ARRAY.encode(data)
@@ -130,8 +130,6 @@ def encode(self, data: np.ndarray) -> EncodedCIFData:
130130

131131
lower_limit=-upper_limit-1
132132

133-
# TODO: figure out if there is a way to implement this
134-
# better & faster with numpy methods.
135133
_pack_values(data, upper_limit, lower_limit, packed)
136134

137135
byte_array_result=BYTE_ARRAY.encode(packed)
@@ -154,6 +152,7 @@ class _PackingInfo:
154152
bytesPerElement: int
155153

156154

155+
@jit(nopython=True)
157156
def _pack_values(data: np.ndarray, upper_limit: int, lower_limit: int, target: np.ndarray) -> None:
158157
offset=0
159158
for value in data:
@@ -200,6 +199,7 @@ def _determine_packing(data: np.ndarray) -> _PackingInfo:
200199
return packing
201200

202201

202+
@jit(nopython=True)
203203
def _packing_size_signed(data: np.ndarray, upper_limit: int) -> int:
204204
lower_limit=-upper_limit-1
205205
size=0
@@ -213,6 +213,7 @@ def _packing_size_signed(data: np.ndarray, upper_limit: int) -> int:
213213
return size + len(data)
214214

215215

216+
@jit(nopython=True)
216217
def _packing_size_unsigned(data: np.ndarray, upper_limit: int) -> int:
217218
size=0
218219

@@ -299,34 +300,16 @@ def encode(self, data: np.ndarray) -> EncodedCIFData:
299300

300301
class StringArray(BinaryCIFEncoder):
301302
def encode(self, data: Union[np.ndarray, list[str]]) -> EncodedCIFData:
302-
_map=dict()
303-
304303
strings: list[str] = []
305304
offsets= [0]
306305
indices=np.empty(len(data), dtype="<i4")
307306

308-
acc_len=0
309-
310-
for i, s in enumerate(data):
311-
# handle null strings.
312-
if not s:
313-
indices[i] =-1
314-
continue
315-
316-
index=_map.get(s)
317-
if index is None:
318-
# increment the length
319-
acc_len+=len(s)
320-
321-
# store the string and index
322-
index=len(strings)
323-
strings.append(s)
324-
_map[s] =index
325-
326-
# write the offset
327-
offsets.append(acc_len)
328-
329-
indices[i] =index
307+
_pack_strings(
308+
data,
309+
indices,
310+
strings,
311+
offsets,
312+
)
330313

331314
encoded_offsets=_OFFSET_ENCODER.encode(np.array(offsets, dtype="<i4"))
332315
encoded_data=_DATA_ENCODER.encode(indices)
@@ -341,4 +324,31 @@ def encode(self, data: Union[np.ndarray, list[str]]) -> EncodedCIFData:
341324

342325
return EncodedCIFData(data=encoded_data["data"], encoding=[encoding])
343326

327+
# TODO: benchmark if JIT helps here
328+
@jit(nopython=False, forceobj=True)
329+
def _pack_strings(data: List[str], indices: np.ndarray, strings: List[str], offsets: List[int]) -> None:
330+
acc_len=0
331+
str_map=dict()
332+
333+
for i, s in enumerate(data):
334+
# handle null strings.
335+
if not s:
336+
indices[i] =-1
337+
continue
338+
339+
index=str_map.get(s)
340+
if index is None:
341+
# increment the length
342+
acc_len+=len(s)
343+
344+
# store the string and index
345+
index=len(strings)
346+
strings.append(s)
347+
str_map[s] =index
348+
349+
# write the offset
350+
offsets.append(acc_len)
351+
352+
indices[i] =index
353+
344354
STRING_ARRAY=StringArray()

0 commit comments

Comments
(0)