Source code for nabu.app.double_flatfield

import numpy as np
from ..preproc.double_flatfield import DoubleFlatField
from ..preproc.flatfield import FlatField
from ..io.writer import NXProcessWriter
from ..resources.dataset_analyzer import analyze_dataset
from ..resources.nxflatfield import update_dataset_info_flats_darks
from ..resources.logger import Logger, LoggerOrPrint
from .cli_configs import DFFConfig
from .utils import parse_params_values


class DoubleFlatFieldChunks:
    def __init__(
        self, dataset_path, output_file, chunk_size=100, sigma=None, do_flatfield=True, h5_entry=None, logger=None
    ):
        self.logger = LoggerOrPrint(logger)
        self.dataset_info = analyze_dataset(dataset_path, extra_options={"hdf5_entry": h5_entry}, logger=logger)
        # Number of detector rows processed per batch, capped to the radios height
        self.chunk_size = min(chunk_size, self.dataset_info.radio_dims[-1])
        self.do_flatfield = bool(do_flatfield)
        if self.do_flatfield:
            update_dataset_info_flats_darks(self.dataset_info, flatfield_mode=True)
        self.output_file = output_file
        # A near-zero sigma is treated as "no filtering"
        self.sigma = sigma if sigma is not None and abs(sigma) > 1e-5 else None

    def _get_config(self):
        conf = {
            "dataset": self.dataset_info.location,
            "entry": self.dataset_info.hdf5_entry or None,
            "dff_sigma": self.sigma,
            "do_flatfield": self.do_flatfield,
        }
        return conf

    def _read_projections(self, chunk_size, start_idx=0):
        # Read a horizontal slab of every projection: rows [start_idx, start_idx + chunk_size)
        reader_kwargs = {"sub_region": (slice(None), slice(start_idx, start_idx + chunk_size), slice(None))}
        if self.dataset_info.kind == "edf":
            # EDF datasets are read whole, using several reading threads
            reader_kwargs = {"n_reading_threads": 4}
        self.reader = self.dataset_info.get_reader(**reader_kwargs)
        self.projections = self.reader.load_data()

    def _init_flatfield(self, start_z=None, end_z=None):
        if not self.do_flatfield:
            return
        chunk_size = end_z - start_z if start_z is not None else self.chunk_size
        # Flats and darks are cropped to the same vertical sub-region as the projections
        self.flatfield = FlatField(
            (self.dataset_info.n_angles, chunk_size, self.dataset_info.radio_dims[0]),
            flats={k: arr[start_z:end_z, :] for k, arr in self.dataset_info.flats.items()},
            darks={k: arr[start_z:end_z, :] for k, arr in self.dataset_info.darks.items()},
            radios_indices=sorted(self.dataset_info.projections.keys()),
        )

    def _apply_flatfield(self, start_z=None, end_z=None):
        if self.do_flatfield:
            self._init_flatfield(start_z=start_z, end_z=end_z)
            self.flatfield.normalize_radios(self.projections)

    def _init_dff(self):
        self.double_flatfield = DoubleFlatField(
            self.projections.shape,
            input_is_mlog=False,
            output_is_mlog=False,
            average_is_on_log=self.sigma is not None,
            sigma_filter=self.sigma,
        )

    def compute_double_flatfield(self):
        """
        Compute the double flatfield for the current dataset.
        """
        n_z = self.dataset_info.radio_dims[-1]
        chunk_size = self.chunk_size
        n_steps = n_z // chunk_size
        extra_step = bool(n_z % chunk_size)
        res = np.zeros(self.dataset_info.radio_dims[::-1])
        for i in range(n_steps):
            self.logger.debug("Computing DFF batch %d/%d" % (i + 1, n_steps + int(extra_step)))
            subregion = (None, None, i * chunk_size, (i + 1) * chunk_size)
            self._read_projections(chunk_size, start_idx=i * chunk_size)
            self._apply_flatfield(start_z=i * chunk_size, end_z=(i + 1) * chunk_size)
            self._init_dff()
            dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True)
            res[subregion[-2] : subregion[-1]] = dff[:]
        # Need to initialize objects with a different shape
        if extra_step:
            curr_idx = (i + 1) * self.chunk_size
            self.logger.debug("Computing DFF batch %d/%d" % (i + 2, n_steps + int(extra_step)))
            self._read_projections(n_z - curr_idx, start_idx=curr_idx)
            self._apply_flatfield(start_z=(i + 1) * chunk_size, end_z=n_z)
            self._init_dff()
            dff = self.double_flatfield.compute_double_flatfield(self.projections, recompute=True)
            res[curr_idx:] = dff[:]
        return res

    def write_double_flatfield(self, arr):
        """
        Write the double flatfield image to a file
        """
        writer = NXProcessWriter(
            self.output_file,
            entry=self.dataset_info.hdf5_entry or "entry",
            filemode="a",
            overwrite=True,
        )
        writer.write(arr, "double_flatfield", config=self._get_config())
        self.logger.info("Wrote %s" % writer.fname)


def dff_cli():
    args = parse_params_values(
        DFFConfig, parser_description="A command-line utility for computing the double flatfield of a dataset."
    )
    logger = Logger("nabu_double_flatfield", level=args["loglevel"], logfile="nabu_double_flatfield.log")
    output_file = args["output"]
    dff = DoubleFlatFieldChunks(
        args["dataset"],
        output_file,
        chunk_size=args["chunk_size"],
        sigma=args["sigma"],
        do_flatfield=bool(args["flatfield"]),
        h5_entry=args["entry"] or None,
        logger=logger,
    )
    dff_image = dff.compute_double_flatfield()
    dff.write_double_flatfield(dff_image)
    return 0


if __name__ == "__main__":
    dff_cli()
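
A minimal sketch of programmatic usage, based only on the signatures above and assuming a NeXus/HDF5 acquisition; the dataset and output paths are hypothetical:

    from nabu.app.double_flatfield import DoubleFlatFieldChunks

    dff = DoubleFlatFieldChunks(
        "/path/to/scan/my_dataset.nx",  # hypothetical acquisition file, parsed by analyze_dataset()
        "/path/to/dff_output.h5",       # hypothetical output, written as a "double_flatfield" NXprocess
        chunk_size=100,                 # number of detector rows processed per batch
        sigma=None,                     # e.g. 3.0 enables sigma_filter and log-domain averaging in DoubleFlatField
        do_flatfield=True,              # normalize projections with flats/darks before averaging
    )
    dff_image = dff.compute_double_flatfield()
    dff.write_double_flatfield(dff_image)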