Skip to content

Commit 02b8195

Browse files
committed
support passing arrays to FieldDesc + CategoryDesc and CategoryWriter as @DataClass
1 parent e20b6a6 commit 02b8195

File tree

4 files changed

+72
-19
lines changed

4 files changed

+72
-19
lines changed

‎ciftools/binary/writer.py‎

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,32 @@ def _encode_field(field: FieldDesc, data: list[_ContextData], total_count: int)
101101
offset=0
102102
for_dindata:
103103
d=_d.data
104-
foriinrange(_d.count):
105-
p=presence(d, i)
106-
ifpisnotValuePresenceEnum.Present:
107-
mask[offset] =p
108-
ifis_native:
109-
array[offset] =None
110-
all_present=False
111-
else:
112-
mask[offset] =ValuePresenceEnum.Present
113-
array[offset] =field.value(d, i)
114-
115-
offset+=1
104+
105+
arrays=field.arrays(d)
106+
ifarraysisnotNone:
107+
iflen(arrays.values) !=_d.count:
108+
raiseValueError(f"values provided in arrays() must have the same length as the category count field")
109+
110+
ifarrays.maskisnotNone:
111+
iflen(arrays.mask) !=_d.count:
112+
raiseValueError(f"mask provided in arrays() must have the same length as the category count field")
113+
mask[offset:offset+_d.count] =arrays.mask
114+
offset+=_d.count
115+
116+
else:
117+
# TODO: use numba JIT for this
118+
foriinrange(_d.count):
119+
p=presence(d, i)
120+
ifpisnotValuePresenceEnum.Present:
121+
mask[offset] =p
122+
ifis_native:
123+
array[offset] =None
124+
all_present=False
125+
else:
126+
mask[offset] =ValuePresenceEnum.Present
127+
array[offset] =field.value(d, i)
128+
129+
offset+=1
116130

117131
encoder=field.encoder(data[0].data) iflen(data) >0else_BYTE_ARRAY_ENCODER
118132
encoded=encoder.encode_cif_data(array)

‎ciftools/writer/base.py‎

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
importabc
2-
fromtypingimportAny, Union
2+
fromdataclassesimportdataclass
3+
fromtypingimportAny, List, Optional, Union
34

45
importnumpyasnp
56
fromciftools.binary.encoding.impl.binary_cif_encoderimportBinaryCIFEncoder
67
fromciftools.cif_format.value_presenceimportValuePresenceEnum
78

9+
@dataclass
10+
classFieldArrays:
11+
values: Union[np.ndarray, List[str]]
12+
# uint8 array, 0 = defined, 1 = ., 2 = ?
13+
mask: Optional[np.ndarray] =None
814

915
classFieldDesc(abc.ABC):
1016
name: str
@@ -17,6 +23,10 @@ def create_array(self, total_count: int) -> Union[np.ndarray, list]:
1723
defvalue(self, data: Any, i: int) ->Any:
1824
pass
1925

26+
@abc.abstractmethod
27+
defarrays(self, data: Any) ->Optional[FieldArrays]:
28+
pass
29+
2030
@abc.abstractmethod
2131
defencoder(self, data: Any) ->BinaryCIFEncoder:
2232
pass
@@ -25,12 +35,13 @@ def encoder(self, data: Any) -> BinaryCIFEncoder:
2535
defpresence(self, data: any, i: int) ->ValuePresenceEnum:
2636
pass
2737

28-
38+
@dataclass
2939
classCategoryDesc:
3040
name: str
3141
fields: list[FieldDesc]
3242

3343

44+
@dataclass
3445
classCategoryWriter:
3546
data: any
3647
count: int

‎ciftools/writer/fields.py‎

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
fromciftools.binary.encoding.impl.binary_cif_encoderimportBinaryCIFEncoder
55
fromciftools.binary.encoding.impl.encoders.string_arrayimportSTRING_ARRAY_CIF_ENCODER
66
fromciftools.cif_format.value_presenceimportValuePresenceEnum
7-
fromciftools.writer.baseimportFieldDesc
7+
fromciftools.writer.baseimportFieldArrays, FieldDesc
88

99
_STRING_ARRAY_ENCODER=BinaryCIFEncoder([STRING_ARRAY_CIF_ENCODER])
1010

@@ -23,24 +23,30 @@ def create_array(self, total_count: int):
2323
defencoder(self, data: Any):
2424
return_STRING_ARRAY_ENCODER
2525

26+
defarrays(self, data: Any) ->Optional[FieldArrays]:
27+
returnself._arrays(data) ifself._arraysisnotNoneelseNone
28+
2629
def__init__(
2730
self,
2831
name: str,
2932
value: Callable[[Any, int], Optional[str]],
3033
presence: Optional[Callable[[Any, int], Optional[ValuePresenceEnum]]] =None,
34+
arrays: Optional[Callable[[Any], FieldArrays]] =None,
3135
) ->None:
3236
self.name=name
3337
self._value=value
3438
self._presence=presence
39+
self._arrays=arrays
3540

3641

3742
defstring_field(
3843
*,
3944
name: str,
4045
value: Callable[[Any, int], Optional[str]],
4146
presence: Optional[Callable[[Any, int], Optional[ValuePresenceEnum]]] =None,
47+
arrays: Optional[Callable[[Any], FieldArrays]] =None,
4248
) ->FieldDesc:
43-
return_StringFieldDesc(name=name, value=value, presence=presence)
49+
return_StringFieldDesc(name=name, value=value, presence=presence, arrays=arrays)
4450

4551

4652
# TODO: derive from FieldDesc
@@ -57,27 +63,33 @@ def presence(self, data: any, i: int) -> ValuePresenceEnum:
5763
defcreate_array(self, total_count: int):
5864
returnnp.empty(total_count, dtype=self._dtype)
5965

66+
defarrays(self, data: Any) ->Optional[FieldArrays]:
67+
returnself._arrays(data) ifself._arraysisnotNoneelseNone
68+
6069
def__init__(
6170
self,
6271
name: str,
6372
value: Callable[[Any, int], Optional[Union[int, float]]],
6473
dtype: np.dtype,
6574
encoder: Callable[[Any], BinaryCIFEncoder],
6675
presence: Optional[Callable[[Any, int], Optional[ValuePresenceEnum]]] =None,
76+
arrays: Optional[Callable[[Any], FieldArrays]] =None
6777
) ->None:
6878
self.name=name
6979
self._value=value
7080
self._dtype=dtype
7181
self._encoder=encoder
7282
self._presence=presence
83+
self._arrays=arrays
7384

7485

7586
defnumber_field(
7687
*,
7788
name: str,
78-
value: Callable[[Any, int], Optional[Union[int, float]]],
89+
value: Optional[Callable[[Any, int], Optional[Union[int, float]]]] =None,
7990
dtype: np.dtype,
8091
encoder: Callable[[Any], BinaryCIFEncoder],
8192
presence: Optional[Callable[[Any, int], Optional[ValuePresenceEnum]]] =None,
93+
arrays: Optional[Callable[[Any], FieldArrays]] =None,
8294
) ->FieldDesc:
83-
return_NumberFieldDesc(name=name, value=value, dtype=dtype, encoder=encoder, presence=presence)
95+
return_NumberFieldDesc(name=name, value=value, dtype=dtype, encoder=encoder, presence=presence, arrays=arrays)

‎tests/_encoding.py‎

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
fromciftools.binary.encoding.impl.encoders.integer_packingimportINTEGER_PACKING_CIF_ENCODER
99
fromciftools.binary.writerimportBinaryCIFWriter
1010
fromciftools.cif_format.binary.fileimportBinaryCIFFile
11-
fromciftools.writer.baseimportCategoryDesc, CategoryWriter, CategoryWriterProvider, FieldDesc, OutputStream
11+
fromciftools.writer.baseimportCategoryDesc, CategoryWriter, CategoryWriterProvider, FieldArrays, FieldDesc, OutputStream
1212
fromciftools.writer.fieldsimportnumber_field, string_field
1313

1414

@@ -102,6 +102,17 @@ def lattice_value_getter(lid: int):
102102
value=lambdadata, i: data.volume[i],
103103
)
104104
)
105+
fields.append(
106+
number_field(
107+
name=f"volume_array",
108+
dtype="f4",
109+
encoder=lambda_: BinaryCIFEncoder(
110+
[FixedPointCIFEncoder(1000), DELTA_CIF_ENCODER, INTEGER_PACKING_CIF_ENCODER]
111+
),
112+
value=lambdadata, i: data.volume[i],
113+
arrays=lambdadata: FieldArrays(values=data.volume),
114+
)
115+
)
105116
fields.append(string_field(name="annotation", value=lambdadata, i: data.annotation[i]))
106117

107118
returnTestCategoryWriter(ctx, self.length, TestCategoryDesc("volume", fields))
@@ -169,6 +180,11 @@ def test(self):
169180
compare=np.allclose(test_data.volume, volume, atol=1e-3)
170181
self.assertTrue(compare, "Volume did not match original data")
171182

183+
volume_array=volume_and_lattices.get_column("volume_array").__dict__["_values"]
184+
print("Volume Array (parsed): "+str(volume_array))
185+
compare=np.allclose(test_data.volume, volume_array, atol=1e-3)
186+
self.assertTrue(compare, "Volume Array did not match original data")
187+
172188
forlattice_idinlattice_ids:
173189
print("Lattice: "+str(lattice_id))
174190
lattice_value=volume_and_lattices.get_column("lattice_"+str(lattice_id)).__dict__["_values"]

0 commit comments

Comments
(0)