import os
import posixpath
from multiprocessing.pool import ThreadPool
import numpy as np
from silx.io.dictdump import dicttonx, nxtodict
from ..misc.binning import binning as image_binning
from ..io.utils import get_first_hdf5_entry
from ..pipeline.config_validators import optional_tuple_of_floats_validator, optional_positive_integer_validator
from .cli_configs import ShrinkConfig
from .utils import parse_params_values
[docs]
def access_nested_dict(dict_, path, default=None):
    """
    Look up a value in a nested dictionary using a "/"-separated path.

    Parameters
    ----------
    dict_: dict
        Dictionary whose values may themselves be dictionaries.
    path: str
        Keys joined by "/", e.g. "instrument/detector/data".
        Leading, trailing and repeated separators are ignored.
        An empty path (or "/") designates `dict_` itself.
    default: object, optional
        Value returned when a component of `path` is missing, at any depth.

    Returns
    -------
    The value stored at `path`, or `default` if a key along the way is absent.
    """
    keys = [key for key in path.split(posixpath.sep) if key]
    if not keys:
        # Nothing to descend into: the path points at the root dictionary
        return dict_
    node = dict_
    # Descend through the intermediate levels; the leaf is resolved below
    for key in keys[:-1]:
        if key not in node:
            return default
        node = node[key]
    return node.get(keys[-1], default)
[docs]
def set_nested_dict_value(dict_, path, val):
    """
    Store `val` at location `path` of a nested dictionary, in place.

    Parameters
    ----------
    dict_: dict
        Dictionary whose values may themselves be dictionaries.
    path: str
        Keys joined by "/". The last component is the key to assign; the
        preceding components designate the (already existing) parent levels.
        A path with a single component assigns directly into `dict_`.
    val: object
        Value to store.

    Raises
    ------
    KeyError
        If one of the parent levels of `path` does not exist.
    """
    dirname, basename = posixpath.split(path)
    parent = dict_
    for key in dirname.split(posixpath.sep):
        if not key:
            # Empty components come from a leading "/" or from an empty dirname
            continue
        if key not in parent:
            raise KeyError("Cannot set %r: parent level %r does not exist" % (path, key))
        parent = parent[key]
    parent[basename] = val
[docs]
def shrink_dataset(input_file, output_file, binning=None, subsampling=None, entry=None, n_threads=1):
    """
    Write a reduced copy of one entry of a NX file, with binned and/or subsampled data.

    The entry is loaded as a nested dictionary with `nxtodict`, modified in
    memory, and written to `output_file` with `dicttonx` under the same entry.

    Parameters
    ----------
    input_file: str
        File to read.
    output_file: str
        File to write.
    binning: tuple of int, optional
        Binning factors applied to each item of "instrument/detector/data":
        `binning[0]` divides axis 1 of the stack and `binning[1]` divides axis 2.
        No binning is done if None.
    subsampling: int, optional
        Step used to slice the detector data and the related per-frame datasets
        (`[::subsampling]`). No subsampling is done if None.
    entry: str, optional
        Entry to process. If not provided (or empty), the value returned by
        `get_first_hdf5_entry(input_file)` is used.
    n_threads: int, optional
        Size of the thread pool used for binning.

    Raises
    ------
    ValueError
        If the entry has no "instrument/detector/data".
    """
    if not entry:
        entry = get_first_hdf5_entry(input_file)
    nx_tree = nxtodict(input_file, path=entry, dereference_links=False)

    # Datasets sliced along their first axis when subsampling is requested
    subsampled_paths = (
        "control/data",
        "instrument/detector/count_time",
        "instrument/detector/data",
        "instrument/detector/image_key",
        "instrument/detector/image_key_control",
        "sample/rotation_angle",
        "sample/x_translation",
        "sample/y_translation",
        "sample/z_translation",
    )

    frames = access_nested_dict(nx_tree, "instrument/detector/data")
    if frames is None:
        raise ValueError("No data found in %s entry %s" % (input_file, entry))

    if binning is not None:

        def _bin_into(frame_and_output):
            # Bin one frame and write the result into its pre-allocated slot
            frame, output = frame_and_output
            output[:] = image_binning(frame, binning)

        shape = frames.shape
        binned_shape = (shape[0], shape[1] // binning[0], shape[2] // binning[1])
        binned = np.zeros(binned_shape, frames.dtype)
        with ThreadPool(n_threads) as pool:
            pool.map(_bin_into, zip(frames, binned))
        # Rebinding the local name means this function no longer holds the unbinned stack
        frames = binned
        set_nested_dict_value(nx_tree, "instrument/detector/data", binned)

    if subsampling is not None:
        step = slice(None, None, subsampling)
        for h5_path in subsampled_paths:
            current = access_nested_dict(nx_tree, h5_path)
            if current is None:
                # Dataset absent from this entry: nothing to subsample
                continue
            set_nested_dict_value(nx_tree, h5_path, current[step])

    dicttonx(nx_tree, output_file, h5path=entry)
[docs]
def shrink_cli():
    """
    Command-line entry point: parse the arguments and call `shrink_dataset`.

    The process exits with status 1 (after printing a message) if the input
    file does not exist, or if the output file already exists.

    Returns
    -------
    int
        0 once the dataset has been written.
    """
    args = parse_params_values(ShrinkConfig, parser_description="Shrink a NX dataset")
    input_file = args["input_file"]
    output_file = args["output_file"]

    # `raise SystemExit` rather than `exit()`: the latter is a helper injected by
    # the `site` module for interactive use and is not guaranteed to be defined.
    if not os.path.isfile(input_file):
        print("No such file: %s" % input_file)
        raise SystemExit(1)
    if os.path.isfile(output_file):
        print("Output file %s already exists, not overwriting it" % output_file)
        raise SystemExit(1)

    binning = optional_tuple_of_floats_validator("", "binning", args["binning"])  # pylint: disable=E1121
    if binning is not None:
        # The validator yields floats; binning factors are used as integers
        binning = tuple(map(int, binning))
    subsampling = optional_positive_integer_validator("", "subsampling", args["subsampling"])  # pylint: disable=E1121

    shrink_dataset(
        input_file,
        output_file,
        binning=binning,
        subsampling=subsampling,
        entry=args["entry"],
        n_threads=args["threads"],
    )
    return 0
if __name__ == "__main__":
    # Use the value returned by shrink_cli() as the process exit status
    raise SystemExit(shrink_cli())