# Source code for nabu.pipeline.helical.helical_chunked_regridded

# pylint: skip-file

from os import path
from time import time
import numpy as np
import math
from silx.image.tomography import get_next_power
from scipy import ndimage as nd
import h5py
import silx.io
import copy
from silx.io.url import DataUrl
from ...resources.logger import LoggerOrPrint
from ...resources.utils import is_hdf5_extension, extract_parameters
from ...io.reader_helical import ChunkReaderHelical, get_hdf5_dataset_shape
from ...preproc.flatfield_variable_region import FlatFieldDataVariableRegionUrls as FlatFieldDataHelicalUrls
from ...preproc.distortion import DistortionCorrection
from ...preproc.shift import VerticalShift
from ...preproc.double_flatfield_variable_region import DoubleFlatFieldVariableRegion as DoubleFlatFieldHelical
from ...preproc.phase import PaganinPhaseRetrieval
from ...reconstruction.sinogram import SinoBuilder, SinoNormalization
from ...misc.unsharp import UnsharpMask
from ...misc.histogram import PartialHistogram, hist_as_2Darray
from ..utils import use_options, pipeline_step

from ...resources.utils import extract_parameters
from ..detector_distortion_provider import DetectorDistortionProvider

from .utils import (
    WriterConfiguratorHelical as WriterConfigurator,
)  # .utils is the same as ..utils but internally we retouch the key associated to "tiffwriter" of Writers to

# point to our class which can write tiff with names indexed by the z height above the sample stage in millimeters

from numpy.lib.stride_tricks import sliding_window_view
from ...misc.binning import get_binning_function
from .helical_utils import find_mirror_indexes

from ...preproc.ccd import Log, CCDFilter

from . import gridded_accumulator

# For now we don't have a plain python/numpy backend for reconstruction
Backprojector = None


class HelicalChunkedRegriddedPipeline:
    """
    Pipeline for "helical" full or half field tomography.
    Data is processed by chunks. A chunk consists in K+-1 contiguous lines of all the radios
    which are read at variable height following the translations.
    """

    # Extra reading space (in number of radios) which allows reading the redundant part,
    # which might be slightly larger and/or require an extra border for interpolation.
    extra_marge_granularity = 4

    FlatFieldClass = FlatFieldDataHelicalUrls
    DoubleFlatFieldClass = DoubleFlatFieldHelical
    CCDFilterClass = CCDFilter
    MLogClass = Log
    PaganinPhaseRetrievalClass = PaganinPhaseRetrieval
    UnsharpMaskClass = UnsharpMask
    VerticalShiftClass = VerticalShift
    SinoBuilderClass = SinoBuilder
    # _init_sino_normalization instantiates self.SinoNormalizationClass: provide the CPU default here,
    # derived classes may override it.
    SinoNormalizationClass = SinoNormalization
    FBPClass = Backprojector
    HBPClass = None
    HistogramClass = PartialHistogram

    regular_accumulator = None

    def __init__(
        self,
        process_config,
        sub_region,
        logger=None,
        extra_options=None,
        phase_margin=None,
        reading_granularity=10,
        span_info=None,
    ):
        """
        Initialize a "HelicalChunked" pipeline.

        Parameters
        ----------
        process_config: `nabu.resources.processconfig.ProcessConfig`
            Process configuration.
        sub_region: tuple
            Sub-region to process in the volume for this worker, in the format
            `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`.
        logger: `nabu.app.logger.Logger`, optional
            Logger class
        extra_options: dict, optional
            Advanced extra options.
        phase_margin: tuple, optional
            Margin to use when performing phase retrieval, in the form ((up, down), (left, right)).
            See also the documentation of PaganinPhaseRetrieval.
            If not provided, no margin is applied.
        reading_granularity: int
            The data angular span which needs to be read for a reconstruction is read step by step,
            reading each time a maximum of reading_granularity radios, and doing the preprocessing
            till phase retrieval for each of these angular groups.
        span_info: optional
            Description of the helical span. The pipeline reads its attributes
            `sunshine_starts`, `sunshine_ends`, `projection_angles_deg`, `z_pix_per_proj`
            and `z_offset_mm`.

        Notes
        ------
        Using a `phase_margin` results in a lesser number of reconstructed slices.
        More specifically, if `phase_margin = (V, H)`, then there will be `chunk_size - 2*V`
        reconstructed slices (if the sub-region is in the middle of the volume)
        or `chunk_size - V` reconstructed slices (if the sub-region is on top or bottom of the volume).
        """
        self.span_info = span_info
        self.reading_granularity = reading_granularity
        self.logger = LoggerOrPrint(logger)
        self._set_params(process_config, sub_region, extra_options, phase_margin)
        self._init_pipeline()

    def _set_params(self, process_config, sub_region, extra_options, phase_margin):
        """Store the configuration, the sub-region and the margins, and reset the internal registries."""
        self.process_config = process_config
        self.dataset_info = self.process_config.dataset_info
        # Copy so that later modifications do not alter the configuration object
        self.processing_steps = self.process_config.processing_steps.copy()
        self.processing_options = self.process_config.processing_options
        sub_region = self._check_subregion(sub_region)
        self.chunk_size = sub_region[-1] - sub_region[-2]
        self.radios_buffer = None
        self._set_detector_distortion_correction()
        self.set_subregion(sub_region)
        self._set_phase_margin(phase_margin)
        self._set_extra_options(extra_options)
        self._callbacks = {}
        self._steps_name2component = {}
        self._steps_component2name = {}
        self._data_dump = {}
        self._resume_from_step = None

    @staticmethod
    def _check_subregion(sub_region):
        """
        Normalize and validate a sub-region.

        A 2-element sub-region `(start_z, end_z)` (tuple or list) is expanded to
        `(None, None, start_z, end_z)`.

        Raises
        ------
        ValueError
            If start_z or end_z is None.
        """
        if len(sub_region) < 4:
            assert len(sub_region) == 2, " at least start_z and end_z are required in subregion"
            # tuple() so that a list such as [z0, z1] is accepted as well
            sub_region = (None, None) + tuple(sub_region)
        if None in sub_region[-2:]:
            raise ValueError("Cannot set z_min or z_max to None")
        return sub_region

    def _set_extra_options(self, extra_options):
        """Store a private copy of the advanced options (an empty dict when None)."""
        advanced_options = {}
        if extra_options is not None:
            advanced_options.update(extra_options)
        self.extra_options = advanced_options

    def _set_phase_margin(self, phase_margin):
        """Store the ((up, down), (left, right)) phase margin; None means no margin."""
        if phase_margin is None:
            phase_margin = ((0, 0), (0, 0))
        (self._phase_margin_up, self._phase_margin_down), (
            self._phase_margin_left,
            self._phase_margin_right,
        ) = phase_margin
def set_subregion(self, sub_region):
    """
    Select the sub-region of the volume to process.

    Parameters
    ----------
    sub_region: tuple
        Either `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`.
        Its vertical extent (end_z - start_z) must match the chunk size
        this pipeline was created for.

    Raises
    ------
    ValueError
        When the vertical extent differs from `self.chunk_size`.
    """
    checked = self._check_subregion(sub_region)
    z_first, z_last = checked[-2], checked[-1]
    height = z_last - z_first
    if height != self.chunk_size:
        raise ValueError(
            "Class was initialized for chunk_size = %d but provided sub_region has chunk_size = %d"
            % (self.chunk_size, height)
        )
    self.sub_region = checked
    self.z_min, self.z_max = z_first, z_last
def _compute_phase_kernel_margin(self): """ Get the "margin" to pass to classes like PaganinPhaseRetrieval. In order to have a good accuracy for filter-based phase retrieval methods, we need to load extra data around the edges of each image. Otherwise, a default padding type is applied. """ if not (self.use_radio_processing_margin): self._phase_margin = None return up_margin = self._phase_margin_up down_margin = self._phase_margin_down # Horizontal margin is not implemented left_margin, right_margin = (0, 0) self._phase_margin = ((up_margin, down_margin), (left_margin, right_margin)) @property def use_radio_processing_margin(self): return ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps) def _get_phase_margin(self): if not (self.use_radio_processing_margin): return ((0, 0), (0, 0)) return self._phase_margin @property def phase_margin(self): """ Return the margin for phase retrieval in the form ((up, down), (left, right)) """ return self._get_phase_margin() def _get_process_name(self, kind="reconstruction"): # In the future, might be something like "reconstruction-<ID>" if kind == "reconstruction": return "reconstruction" elif kind == "histogram": return "histogram" return kind def _configure_dump(self, step_name): if step_name not in self.processing_steps: if step_name == "sinogram" and self.process_config._dump_sinogram: fname_full = self.process_config._dump_sinogram_file else: return else: if not self.processing_options[step_name].get("save", False): return fname_full = self.processing_options[step_name]["save_steps_file"] fname, ext = path.splitext(fname_full) dirname, file_prefix = path.split(fname) output_dir = path.join(dirname, file_prefix) file_prefix += str("_%06d" % self._get_image_start_index()) self.logger.info("omitting config in data_dump because of too slow nexus writer ") self._data_dump[step_name] = WriterConfigurator( output_dir, file_prefix, file_format="hdf5", overwrite=True, logger=self.logger, nx_info={ 
"process_name": step_name, "processing_index": 0, # TODO # "config": {"processing_options": self.processing_options, "nabu_config": self.process_config.nabu_config}, "config": None, "entry": getattr(self.dataset_info.dataset_scanner, "entry", None), }, ) def _configure_data_dumps(self): self.process_config._configure_save_steps() for step_name in self.processing_steps: self._configure_dump(step_name) # sinogram is a special keyword: not in processing_steps, but guaranteed to be before sinogram generation if self.process_config._dump_sinogram: self._configure_dump("sinogram") # # Callbacks #
def register_callback(self, step_name, callback):
    """
    Attach a callback to a pipeline processing step.

    Parameters
    ----------
    step_name: str
        Name of the processing step.
    callback: callable
        Function executed once the processing step `step_name` is finished.
        It receives a single argument: the pipeline instance.

    Raises
    ------
    ValueError
        If `step_name` is not one of the processing steps.
    """
    if step_name not in self.processing_steps:
        raise ValueError("'%s' is not in processing steps %s" % (step_name, self.processing_steps))
    # Several callbacks may be registered for one step; they are kept in registration order
    self._callbacks.setdefault(step_name, []).append(callback)
#
# Overwritten in inheriting classes
#


def _get_shape(self, step_name):
    """
    Get the shape to provide to the class corresponding to `step_name`.

    Parameters
    ----------
    step_name: str
        Name of the processing step.

    Returns
    -------
    shape: tuple

    Raises
    ------
    ValueError
        If `step_name` is not a known processing step.
    """
    # Shapes are read lazily from the buffers, because not all of them exist at every
    # stage of the initialization.
    if step_name in ("flatfield", "double_flatfield"):
        shape = self.radios_subset.shape
    elif step_name in ("phase", "unsharp_mask"):
        # These classes process one radio at a time
        shape = self.radios_subset.shape[1:]
    elif step_name == "ccd_correction":
        shape = self.gridded_radios.shape[1:]
    elif step_name in ("take_log", "radios_movements", "sino_normalization"):
        shape = self.radios.shape
    elif step_name in ("sino_normalization_slim", "build_sino"):
        # A single sinogram, keeping a singleton vertical axis: (n_gridded_angles, 1, width)
        shape = self.radios.shape[:1] + (1,) + self.radios.shape[2:]
    elif step_name == "one_sino_slim":
        shape = self.radios.shape[:1] + self.radios.shape[2:]
    elif step_name == "reconstruction":
        shape = self.sino_builder.output_shape[1:]
    else:
        raise ValueError("Unknown processing step %s" % step_name)
    self.logger.debug("Data shape for %s is %s" % (step_name, str(shape)))
    return shape


def _allocate_array(self, shape, dtype, name=None):
    """
    Allocate a zero-initialized array.

    This function can be redefined in a derived class dedicated to GPU, which returns
    GPU arrays. The base implementation allocates on CPU.
    """
    # Must go through self: there is no module-level _cpu_allocate_array
    return self._cpu_allocate_array(shape, dtype, name=name)


def _cpu_allocate_array(self, shape, dtype, name=None):
    """
    Allocate a zero-initialized numpy array.

    Used for the objects of the pre-GPU part: they always live on CPU, even in derived classes.
    `name` is accepted for interface compatibility and is not used here.
    """
    return np.zeros(shape, dtype=dtype)


def _allocate_sinobuilder_output(self):
    """Allocate the (CPU) output buffer of the sinogram builder."""
    return self._cpu_allocate_array(self.sino_builder.output_shape, "f", name="sinos")


def _allocate_recs(self, ny, nx):
    """
    Allocate `recs` (one reconstructed slice) and `recs_stack` (all the slices of the chunk).
    """
    self.n_slices = self.gridded_radios.shape[1]
    if self.use_radio_processing_margin:
        # The vertical margins are cropped away after radio processing
        self.n_slices -= sum(self.phase_margin[0])
    self.recs = self._allocate_array((1, ny, nx), "f", name="recs")
    self.recs_stack = self._cpu_allocate_array((self.n_slices, ny, nx), "f", name="recs_stack")


def _reset_memory(self):
    pass


def _get_read_dump_subregion(self):
    """
    Return the sub-region to read from a processing dump file, relative to the dumped z range,
    or None when no dump file ("process_file") is configured.
    """
    read_opts = self.processing_options["read_chunk"]
    if read_opts.get("process_file", None) is None:
        return None
    dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"]
    relative_start_z = self.z_min - dump_start_z
    relative_end_z = relative_start_z + self.chunk_size
    # (n_angles, n_z, n_x)
    return (None, None, relative_start_z, relative_end_z, None, None)


def _check_resume_from_step(self):
    if self._resume_from_step is None:
        return
    read_opts = self.processing_options["read_chunk"]
    expected_radios_shape = get_hdf5_dataset_shape(
        read_opts["process_file"],
        read_opts["process_h5_path"],
        sub_region=self._get_read_dump_subregion(),
    )
    # TODO check


def _init_reader_finalize(self):
    """
    Method called after _init_reader
    """
    self._check_resume_from_step()
    self._compute_phase_kernel_margin()
    self._allocate_reduced_gridded_and_subset_radios()


def _allocate_reduced_gridded_and_subset_radios(self):
    """
    Compute the regular ("gridded") angular grid and allocate the buffers working on it:
    diagnostic radios, gridded radios and weights, reading subsets, margin-cropped radios
    and the one-sinogram buffer.

    Raises
    ------
    ValueError
        If "read_chunk" is not among the processing steps.
    """
    shp_h = self.chunk_reader.data.shape[-1]

    sliding_window_size = self.chunk_size
    if sliding_window_size % 2 == 0:
        sliding_window_size += 1
    sliding_window_radius = (sliding_window_size - 1) // 2

    # NOTE(review): n_projs_max is not used below in this method.
    if sliding_window_radius == 0:
        n_projs_max = (self.span_info.sunshine_ends - self.span_info.sunshine_starts).max()
    else:
        starts_in = self.span_info.sunshine_starts
        ends_in = self.span_info.sunshine_ends
        # Edge-pad by repeating the first/last value, then take sliding min/max
        padded_starts = np.concatenate(
            [[starts_in[0]] * sliding_window_radius, starts_in, [starts_in[-1]] * sliding_window_radius]
        )
        starts = sliding_window_view(padded_starts, sliding_window_size).min(axis=-1)
        padded_ends = np.concatenate(
            [[ends_in[0]] * sliding_window_radius, ends_in, [ends_in[-1]] * sliding_window_radius]
        )
        ends = sliding_window_view(padded_ends, sliding_window_size).max(axis=-1)
        n_projs_max = (ends - starts).max()

    ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin
    (start_x, end_x, start_z, end_z) = self.sub_region

    # The gridded angles: a regular grid over 360 degrees with the mean acquisition step
    my_angle_step = abs(np.diff(self.span_info.projection_angles_deg).mean())
    self.n_gridded_angles = int(round(360.0 / my_angle_step))
    self.my_angles_rad = np.arange(self.n_gridded_angles) * 2 * np.pi / self.n_gridded_angles
    my_angles_deg = np.rad2deg(self.my_angles_rad)
    self.mirror_angle_relative_indexes = find_mirror_indexes(my_angles_deg)

    if "read_chunk" not in self.processing_steps:
        raise ValueError("Cannot proceed without reading data")

    r_shp_v, r_shp_h = self.whole_radio_shape
    subradio_shape = end_z - start_z, r_shp_h

    # These radios are for diagnostic of the translations (they will be optionally written, for being
    # further used by correlation techniques). Two radios for the first two passes over the first gridded angles.
    self.diagnostic_radios = np.zeros((2,) + subradio_shape, np.float32)
    self.diagnostic_weights = np.zeros((2,) + subradio_shape, np.float32)
    self.diagnostic_proj_angle = np.zeros([2], "f")
    self.diagnostic = {
        "radios": self.diagnostic_radios,
        "weights": self.diagnostic_weights,
        "angles": self.diagnostic_proj_angle,
    }

    self.gridded_radios = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32)
    self.gridded_cumulated_weights = np.zeros((self.n_gridded_angles,) + subradio_shape, np.float32)
    self.radios_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32)
    self.radios_weights_subset = np.zeros((self.reading_granularity,) + subradio_shape, np.float32)

    # radios without the vertical phase margins
    self.radios = np.zeros(
        (self.n_gridded_angles,) + ((end_z - down_margin) - (start_z + up_margin), shp_h), np.float32
    )
    self.radios_weights = np.zeros_like(self.radios)
    self.radios_slim = self._allocate_array(self._get_shape("one_sino_slim"), "f", name="radios_slim")


def _process_finalize(self):
    """
    Method called once the pipeline has been executed
    """
    pass


def _get_slice_start_index(self):
    """Index of the first reconstructed slice: the chunk start shifted by the upper phase margin."""
    return self.z_min + self._phase_margin_up


_get_image_start_index = _get_slice_start_index


#
# Pipeline initialization
#


def _init_pipeline(self):
    """Initialize every processing component, in dependency order."""
    self._get_size_of_a_raw_radio()
    self._init_reader()
    self._init_flatfield()
    self._init_double_flatfield()
    self._init_weights_field()
    self._init_ccd_corrections()
    self._init_phase()
    self._init_unsharp()
    self._init_mlog()
    self._init_sino_normalization()
    self._init_sino_builder()
    self._prepare_reconstruction()
    self._init_reconstruction()
    self._init_histogram()
    self._init_writer()
    self._configure_data_dumps()
    self._configure_regular_accumulator()


def _set_detector_distortion_correction(self):
    """Create `self.detector_corrector` from the "preproc" configuration, or set it to None."""
    preproc_config = self.process_config.nabu_config["preproc"]
    if preproc_config["detector_distortion_correction"] is None:
        self.detector_corrector = None
    else:
        self.detector_corrector = DetectorDistortionProvider(
            detector_full_shape_vh=self.process_config.dataset_info.radio_dims[::-1],
            correction_type=preproc_config["detector_distortion_correction"],
            options=preproc_config["detector_distortion_correction_options"],
        )


def _configure_regular_accumulator(self):
    """Create the accumulator which fills the gridded radios/weights and the diagnostic buffers."""
    accumulator_cls = gridded_accumulator.GriddedAccumulator
    self.regular_accumulator = accumulator_cls(
        gridded_radios=self.gridded_radios,
        gridded_weights=self.gridded_cumulated_weights,
        diagnostic_radios=self.diagnostic_radios,
        diagnostic_weights=self.diagnostic_weights,
        diagnostic_angles=self.diagnostic_proj_angle,
        dark=self.flatfield.get_dark(),
        flat_indexes=self.flatfield._sorted_flat_indices,
        flats=self.flatfield.flats_stack,
        weights=self.weights_field.data,
        double_flat=self.double_flatfield.data,
    )


def _get_size_of_a_raw_radio(self):
    """
    Find, once for all, the (binned) shape of a radio by reading one of the input files.
    This information is used elsewhere when allocating data holders.

    Raises
    ------
    ValueError
        If "read_chunk" is not among the processing steps.
    """
    if "read_chunk" not in self.processing_steps:
        raise ValueError("Cannot proceed without reading data")
    options = self.processing_options["read_chunk"]
    here_a_file = next(iter(options["files"].values()))
    here_a_radio = silx.io.get_data(here_a_file)
    binning_x, binning_z = self._get_binning()
    if (binning_z, binning_x) != (1, 1):
        binning_function = get_binning_function((binning_z, binning_x))
        here_a_radio = binning_function(here_a_radio)
    self.whole_radio_shape = here_a_radio.shape
    return self.whole_radio_shape


@use_options("read_chunk", "chunk_reader")
def _init_reader(self):
    if "read_chunk" not in self.processing_steps:
        raise ValueError("Cannot proceed without reading data")
    options = self.processing_options["read_chunk"]
    assert options.get("process_file", None) is None, "Resume not yet implemented in helical pipeline"
    # Dummy initialisation: it will be _set_subregion'ed and set_data_buffer'ed in the loops
    self.chunk_reader = ChunkReaderHelical(
        options["files"],
        sub_region=None,  # setting of subregion will be done later by calls to set_subregion
        detector_corrector=self.detector_corrector,
        convert_float=True,
        binning=options["binning"],
        dataset_subsampling=options["dataset_subsampling"],
        data_buffer=None,
        pre_allocate=True,
    )
    self._init_reader_finalize()


@use_options("flatfield", "flatfield")
def _init_flatfield(self, shape=None):
    if shape is None:
        shape = self._get_shape("flatfield")
    options = self.processing_options["flatfield"]

    distortion_correction = None
    if options["do_flat_distortion"]:
        self.logger.info("Flats distortion correction will be applied")
        estimation_kwargs = {}
        estimation_kwargs.update(options["flat_distortion_params"])
        estimation_kwargs["logger"] = self.logger
        distortion_correction = DistortionCorrection(
            estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn"
        )

    self.flatfield = self.FlatFieldClass(
        shape,
        flats=self.dataset_info.flats,
        darks=self.dataset_info.darks,
        radios_indices=options["projs_indices"],
        interpolation="linear",
        distortion_correction=distortion_correction,
        radios_srcurrent=options["radios_srcurrent"],
        flats_srcurrent=options["flats_srcurrent"],
        detector_corrector=self.detector_corrector,
        # every flat will be read at a different height: no sub_region here
        binning=options["binning"],
        convert_float=True,
    )


def _get_binning(self):
    """Return the (binning_x, binning_z) of the "read_chunk" options, (1, 1) when not set."""
    binning = self.processing_options["read_chunk"]["binning"]
    if binning is None:
        return 1, 1
    return binning


def _init_double_flatfield(self):
    """
    Create `self.double_flatfield`, loading a precomputed result from the "processes_file"
    when that file contains one.
    """
    options = self.processing_options["double_flatfield"]
    binning_x, binning_z = self._get_binning()
    result_url = None

    self.double_flatfield = None
    if options["processes_file"] not in (None, ""):
        file_path = options["processes_file"]
        data_path = (self.dataset_info.hdf5_entry or "entry") + "/double_flatfield/results/data"
        if path.exists(file_path):
            # Close the file right after the membership test instead of leaking the handle
            with h5py.File(file_path, "r") as h5_file:
                has_result = data_path in h5_file
            if has_result:
                result_url = DataUrl(file_path=file_path, data_path=data_path)
                self.logger.info("Loading double flatfield from %s" % result_url.file_path())

    self.double_flatfield = self.DoubleFlatFieldClass(
        self._get_shape("double_flatfield"),
        result_url=result_url,
        binning_x=binning_x,
        binning_z=binning_z,
        detector_corrector=self.detector_corrector,
    )


def _init_weights_field(self):
    """
    Create `self.weights_field`, loading a precomputed result from the "processes_file"
    of the double_flatfield options when that file contains one.
    """
    options = self.processing_options["double_flatfield"]
    result_url = None
    binning_x, binning_z = self.chunk_reader.get_binning()

    self.weights_field = None
    if options["processes_file"] not in (None, ""):
        file_path = options["processes_file"]
        data_path = (self.dataset_info.hdf5_entry or "entry") + "/weights_field/results/data"
        if path.exists(file_path):
            # Close the file right after the membership test instead of leaking the handle
            with h5py.File(file_path, "r") as h5_file:
                has_result = data_path in h5_file
            if has_result:
                result_url = DataUrl(file_path=file_path, data_path=data_path)
                self.logger.info("Loading weights_field from %s" % result_url.file_path())

    self.weights_field = self.DoubleFlatFieldClass(
        self._get_shape("double_flatfield"), result_url=result_url, binning_x=binning_x, binning_z=binning_z
    )


def _init_ccd_corrections(self):
    if "ccd_correction" not in self.processing_steps:
        return
    options = self.processing_options["ccd_correction"]
    median_clip_thresh = options["median_clip_thresh"]
    self.ccd_correction = self.CCDFilterClass(
        self._get_shape("ccd_correction"), median_clip_thresh=median_clip_thresh
    )


@use_options("phase", "phase_retrieval")
def _init_phase(self):
    options = self.processing_options["phase"]
    # If unsharp mask follows phase retrieval, then it should be done
    # before cropping to the "inner part".
    # Otherwise, crop the data just after phase retrieval.
    if "unsharp_mask" in self.processing_steps:
        margin = None
    else:
        margin = self._phase_margin
    self.phase_retrieval = self.PaganinPhaseRetrievalClass(
        self._get_shape("phase"),
        distance=options["distance_m"],
        energy=options["energy_kev"],
        delta_beta=options["delta_beta"],
        pixel_size=options["pixel_size_m"],
        padding=options["padding_type"],
        margin=margin,
        fftw_num_threads=True,  # TODO tune in advanced params of nabu config file
    )
    if self.phase_retrieval.use_fftw:
        self.logger.debug(
            "PaganinPhaseRetrieval using FFTW with %d threads" % self.phase_retrieval.fftw.num_threads
        )


# @use_options("unsharp_mask", "unsharp_mask")
def _init_unsharp(self):
    """Create `self.unsharp_mask` (None when the step is disabled) and store its parameters."""
    if "unsharp_mask" not in self.processing_steps:
        self.unsharp_mask = None
        self.unsharp_sigma = 0.0
        self.unsharp_coeff = 0.0
        self.unsharp_method = "log"
    else:
        options = self.processing_options["unsharp_mask"]
        self.unsharp_sigma = options["unsharp_sigma"]
        self.unsharp_coeff = options["unsharp_coeff"]
        self.unsharp_method = options["unsharp_method"]
        self.unsharp_mask = self.UnsharpMaskClass(
            self._get_shape("unsharp_mask"),
            options["unsharp_sigma"],
            options["unsharp_coeff"],
            mode="reflect",
            method=options["unsharp_method"],
        )


def _init_mlog(self):
    options = self.processing_options["take_log"]
    self.mlog = self.MLogClass(
        self._get_shape("take_log"), clip_min=options["log_min_clip"], clip_max=options["log_max_clip"]
    )


@use_options("sino_normalization", "sino_normalization")
def _init_sino_normalization(self):
    options = self.processing_options["sino_normalization"]
    self.sino_normalization = self.SinoNormalizationClass(
        kind=options["method"],
        radios_shape=self._get_shape("sino_normalization_slim"),
    )


def _init_sino_builder(self):
    options = self.processing_options["reconstruction"]
    # The dedicated build_sino class disappeared; a plain SinoBuilder is used
    self.sino_builder = self.SinoBuilderClass(
        radios_shape=self._get_shape("build_sino"),
        rot_center=options["rotation_axis_position"],
        halftomo=False,
    )
    self._sinobuilder_copy = False
    self._sinobuilder_output = None
    self.sinos = None


# This should be renamed, as it could be confused with _init_reconstruction.
# What about _get_reconstruction_array ?
@use_options("reconstruction", "reconstruction")
def _prepare_reconstruction(self):
    """Store the reconstruction ROI as (start_x, end_x, start_y, end_y) and allocate the slices."""
    options = self.processing_options["reconstruction"]
    x_s, x_e = options["start_x"], options["end_x"]
    y_s, y_e = options["start_y"], options["end_y"]
    # end bounds in the configuration are inclusive
    self._rec_roi = (x_s, x_e + 1, y_s, y_e + 1)
    self._allocate_recs(y_e - y_s + 1, x_e - x_s + 1)


@use_options("reconstruction", "reconstruction")
def _init_reconstruction(self):
    """
    Create the FBP reconstructor (and the HBP one when an HBP class is available).

    Raises
    ------
    ValueError
        If no sinogram builder or no FBP class is available.
    """
    options = self.processing_options["reconstruction"]
    if self.sino_builder is None:
        raise ValueError("Reconstruction cannot be done without build_sino")
    if self.FBPClass is None:
        raise ValueError("No usable FBP module was found")
    rot_center = options["rotation_axis_position"]
    # _rec_roi is (start_x, end_x, start_y, end_y), see _prepare_reconstruction.
    # It used to be unpacked in (y, x) order, which transposed the HBP slice shape
    # for non-square regions of interest.
    start_x, end_x, start_y, end_y = self._rec_roi

    if self.HBPClass is not None:
        fan_source_distance_meters = self.process_config.nabu_config["reconstruction"]["fan_source_distance_meters"]
        self.reconstruction_hbp = self.HBPClass(
            self._get_shape("one_sino_slim"),
            slice_shape=(end_y - start_y, end_x - start_x),
            angles=self.my_angles_rad,
            rot_center=rot_center,
            extra_options={"axis_correction": np.zeros(self.radios.shape[0], "f")},
            axis_source_meters=fan_source_distance_meters,
            voxel_size_microns=options["voxel_size_cm"][0] * 1.0e4,
            scale_factor=1.0 / options["voxel_size_cm"][0],
        )
    else:
        self.reconstruction_hbp = None

    self.reconstruction = self.FBPClass(
        self._get_shape("reconstruction"),
        angles=np.zeros(self.radios.shape[0], "f"),
        rot_center=rot_center,
        filter_name=options["fbp_filter_type"],
        slice_roi=self._rec_roi,
        scale_factor=1.0 / options["voxel_size_cm"][0],
        padding_mode=options["padding_type"],
        extra_options={
            "scale_factor": 1.0 / options["voxel_size_cm"][0],
            "axis_correction": np.zeros(self.radios.shape[0], "f"),
            "clip_outer_circle": options["clip_outer_circle"],
        },
    )

    my_options = self.process_config.nabu_config["reconstruction"]
    if my_options["axis_to_the_center"]:
        # Shift the reconstruction so that the rotation axis sits at the center of the ROI
        off_x = -int(round((start_x + end_x - 1) / 2.0 - rot_center))
        off_y = -int(round((start_y + end_y - 1) / 2.0 - rot_center))
        self.reconstruction.offsets = {"x": off_x, "y": off_y}

    if options["fbp_filter_type"] is None:
        # No filtering requested: plain backprojection
        self.reconstruction.fbp = self.reconstruction.backproj


@use_options("histogram", "histogram")
def _init_histogram(self):
    options = self.processing_options["histogram"]
    self.histogram = self.HistogramClass(method="fixed_bins_number", num_bins=options["histogram_bins"])
    self.histo_stack = []


@use_options("save", "writer")
def _init_writer(self, chunk_info=None):
    """
    Configure the writer of the reconstructed slices, and the histogram writer.

    Parameters
    ----------
    chunk_info: optional
        When provided, its `span_v` attribute is used to compute, for each slice, the height
        above the sample stage (in millimeters) which is handed to the writer configurator.
    """
    options = self.processing_options["save"]
    file_prefix = options["file_prefix"]
    output_dir = path.join(options["location"], file_prefix)
    nx_info = None
    self._hdf5_output = is_hdf5_extension(options["file_format"])

    if chunk_info is not None:
        d_v, d_h = self.process_config.dataset_info.radio_dims[::-1]
        h_rels = self._get_slice_start_index() + np.arange(chunk_info.span_v[1] - chunk_info.span_v[0])
        fact_mm = self.process_config.dataset_info.pixel_size * 1.0e-3
        heights_mm = (
            fact_mm * (-self.span_info.z_pix_per_proj[0] + (d_v - 1) / 2 - h_rels) - self.span_info.z_offset_mm
        )
    else:
        heights_mm = None

    if self._hdf5_output:
        fname_start_index = None
        file_prefix += "_%06d" % self._get_slice_start_index()
        entry = getattr(self.dataset_info.dataset_scanner, "entry", None)
        nx_info = {
            "process_name": self._get_process_name(),
            "processing_index": 0,
            "config": {
                "processing_options": self.processing_options,
                "nabu_config": self.process_config.nabu_config,
            },
            "entry": entry,
        }
        self._histogram_processing_index = nx_info["processing_index"] + 1
    else:
        # Any non-HDF5 format (tif/tiff included) is indexed from the first slice.
        # Previously only tif/tiff were handled, leaving fname_start_index unbound otherwise.
        fname_start_index = self._get_slice_start_index()
        self._histogram_processing_index = 1

    self._writer_configurator = WriterConfigurator(
        output_dir,
        file_prefix,
        file_format=options["file_format"],
        overwrite=options["overwrite"],
        start_index=fname_start_index,
        heights_above_stage_mm=heights_mm,
        logger=self.logger,
        nx_info=nx_info,
        write_histogram=("histogram" in self.processing_steps),
        histogram_entry=getattr(self.dataset_info.dataset_scanner, "entry", "entry"),
    )
    self.writer = self._writer_configurator.writer
    self._writer_exec_args = self._writer_configurator._writer_exec_args
    self._writer_exec_kwargs = self._writer_configurator._writer_exec_kwargs
    self.histogram_writer = self._writer_configurator.get_histogram_writer()


def _apply_expand_fact(self, t):
    """Multiply `t` by the reader dataset subsampling; None is returned unchanged."""
    if t is not None:
        t = t * self.chunk_reader.dataset_subsampling
    return t


def _expand_slice(self, subchunk_slice):
    """Scale start/stop/step of a slice by the dataset subsampling (a None step counts as 1)."""
    start, stop, step = subchunk_slice.start, subchunk_slice.stop, subchunk_slice.step
    if step is None:
        step = 1
    start, stop, step = (self._apply_expand_fact(t) for t in (start, stop, step))
    return slice(start, stop, step)


def _extract_preprocess_with_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, output):
    """
    Read, and apply dark+flat field to, a small angular domain corresponding to the slice
    argument `sub_total_prange_slice`, without refilling the holes. The shifted/cropped
    result of each radio is written into the corresponding item of `output`.
    """
    if self.chunk_reader.dataset_subsampling > 1:
        subsampling_file_slice = self._expand_slice(sub_total_prange_slice)
    else:
        subsampling_file_slice = sub_total_prange_slice

    my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice]
    fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice]
    x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice]
    (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region
    subr_start_z_list = subr_start_z - my_integer_shifts_v
    subr_end_z_list = subr_end_z - my_integer_shifts_v + 1

    # A single vertical window enclosing the windows of all the radios of this angular domain
    floating_start_z = subr_start_z_list.min()
    floating_end_z = subr_end_z_list.max()
    floating_subregion = None, None, floating_start_z, floating_end_z
    self._reset_reader_subregion(floating_subregion)
    self.chunk_reader.load_data(overwrite=True, sub_total_prange_slice=sub_total_prange_slice)

    my_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice]
    data_raw = self.chunk_reader.data[: len(my_indexes)]

    if (self.flatfield is not None) or (self.double_flatfield is not None):
        sub_regions_per_radio = [self.trimmed_floating_subregion] * len(my_indexes)
        if self.flatfield is not None:
            self.flatfield.normalize_radios(data_raw, my_indexes, sub_regions_per_radio)
        if self.double_flatfield is not None:
            self.double_flatfield.apply_double_flatfield_for_sub_regions(data_raw, sub_regions_per_radio)

    source_start_x, source_end_x, source_start_z, sources_end_z = self.trimmed_floating_subregion
    if self.weights_field is not None:
        data_weight = self.weights_field.data[source_start_z:sources_end_z]
    else:
        data_weight = None

    for data_read, list_subr_start_z, list_subr_end_z, fract_shift, x_shift, data_target in zip(
        data_raw, subr_start_z_list, subr_end_z_list, fract_complement_shifts_v, x_shifts_list, output
    ):
        _fill_in_chunk_by_shift_crop_data(
            data_target,
            data_read,
            fract_shift,
            list_subr_start_z,
            list_subr_end_z,
            source_start_z,
            sources_end_z,
            x_shift=x_shift,
        )


def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info):
    """
    Read the radios of `sub_total_prange_slice` and let the regular accumulator apply the
    flat-field preprocessing and accumulate them.
    """
    my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice]
    (subr_start_x, subr_end_x, subr_start_z, subr_end_z) = self.sub_region
    subr_start_z_list = subr_start_z - my_integer_shifts_v
    subr_end_z_list = subr_end_z - my_integer_shifts_v + 1

    self._reset_reader_subregion((None, None, subr_start_z_list.min(), subr_end_z_list.max()))
    dtasrc_start_x, dtasrc_end_x, dtasrc_start_z, dtasrc_end_z = self.trimmed_floating_subregion

    self.chunk_reader.load_data(overwrite=True, sub_total_prange_slice=sub_total_prange_slice)

    if self.chunk_reader.dataset_subsampling > 1:
        subsampling_file_slice = self._expand_slice(sub_total_prange_slice)
    else:
        subsampling_file_slice = sub_total_prange_slice
    my_subsampled_indexes = self.chunk_reader._sorted_files_indices[subsampling_file_slice]
    data_raw = self.chunk_reader.data[: len(my_subsampled_indexes)]

    self.regular_accumulator.extract_preprocess_with_flats(
        subchunk_slice,
        my_subsampled_indexes,
        chunk_info,
        np.array((subr_start_z, subr_end_z), "i"),
        np.array((dtasrc_start_z, dtasrc_end_z), "i"),
        data_raw,
    )
def binning_expanded(self, region):
    """
    Scale a region `(start_x, end_x, start_z, end_z)` by the reader binning factors.

    The x bounds are multiplied by the x binning and the z bounds by the z binning;
    None entries stay None. The result is returned as a list.
    """
    binning_x, binning_z = self.chunk_reader.get_binning()
    factors = (binning_x, binning_x, binning_z, binning_z)
    expanded = []
    for bound, factor in zip(region, factors):
        expanded.append(None if bound is None else bound * factor)
    return expanded
def _reset_reader_subregion(self, floating_subregion):
    """
    Point the chunk reader to a new vertical sub-region and make sure the
    reading buffer is tall enough for it.

    Parameters
    ----------
    floating_subregion: tuple
        (start_x, end_x, start_z, end_z). The z bounds may lie outside the
        radio; they are clipped to [0, self.whole_radio_shape[0]].

    Raises
    ------
    RuntimeError
        If the pipeline is set to resume from a step (not supported here).
    """
    if self._resume_from_step is not None:
        message = "Resume not yet implemented in helical pipeline"
        raise RuntimeError(message)

    start_x, end_x, start_z, end_z = floating_subregion
    # Clip the requested z range to the detector height
    trimmed_start_z = max(0, start_z)
    trimmed_end_z = min(self.whole_radio_shape[0], end_z)
    my_buffer_height = trimmed_end_z - trimmed_start_z

    if self.radios_buffer is None or my_buffer_height > self.safe_buffer_height:
        # The untrimmed height is never smaller than the trimmed one
        self.safe_buffer_height = end_z - start_z
        assert (
            self.safe_buffer_height >= my_buffer_height
        ), "This should always be true, if not contact the developer"
        # Drop the reference to the old buffer before allocating the new one
        self.radios_buffer = None
        self.radios_buffer = np.zeros(
            (self.reading_granularity + self.extra_marge_granularity,)
            + (self.safe_buffer_height, self.whole_radio_shape[1]),
            np.float32,
        )

    self.trimmed_floating_subregion = start_x, end_x, trimmed_start_z, trimmed_end_z
    self.chunk_reader._set_subregion(self.binning_expanded(self.trimmed_floating_subregion))
    self.chunk_reader._init_reader()
    self.chunk_reader._loaded = False
    # The reader fills a view on the first my_buffer_height lines of the buffer
    self.chunk_reader.set_data_buffer(self.radios_buffer[:, :my_buffer_height, :], pre_allocate=False)


def _ccd_corrections(self, radios=None):
    """
    Apply the CCD median-clip correction to each radio.

    Parameters
    ----------
    radios: array, optional
        Stack of radios, first axis being the radio index. Defaults to
        ``self.gridded_radios``.
    """
    if radios is None:
        radios = self.gridded_radios
    if hasattr(self.ccd_correction, "median_clip_correction_multiple_images"):
        self.ccd_correction.median_clip_correction_multiple_images(radios)
    else:
        _tmp_radio = self._cpu_allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio")
        for i in range(radios.shape[0]):
            self.ccd_correction.median_clip_correction(radios[i], output=_tmp_radio)
            radios[i][:] = _tmp_radio[:]


def _retrieve_phase(self):
    """
    Apply the phase-retrieval filter to every gridded radio.

    When unsharp masking is requested, the result is written back into
    ``self.gridded_radios`` (the unsharp step then fills ``self.radios``);
    otherwise it is written directly into ``self.radios``.
    """
    if "unsharp_mask" in self.processing_steps:
        for i in range(self.gridded_radios.shape[0]):
            self.gridded_radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i])
    else:
        for i in range(self.gridded_radios.shape[0]):
            self.radios[i] = self.phase_retrieval.apply_filter(self.gridded_radios[i])


def _nophase_put_to_radios(self, target, source):
    """Copy each image of ``source`` into ``target``, cropping the phase margins."""
    ((up_margin, down_margin), (left_margin, right_margin)) = self.phase_margin
    # ``x or None`` turns a zero margin into "no cropping" on that side
    zslice = slice(up_margin or None, -down_margin or None)
    xslice = slice(left_margin or None, -right_margin or None)
    for i in range(target.shape[0]):
        target[i] = source[i][zslice, xslice]


def _apply_unsharp(self):
    """Unsharp-mask each gridded radio and store the margin-cropped result in ``self.radios``."""
    ((up_margin, down_margin), (left_margin, right_margin)) = self._phase_margin
    zslice = slice(up_margin or None, -down_margin or None)
    xslice = slice(left_margin or None, -right_margin or None)
    for i in range(self.radios.shape[0]):
        self.radios[i] = self.unsharp_mask.unsharp(self.gridded_radios[i])[zslice, xslice]


def _take_log(self):
    """Take the logarithm of ``self.radios`` through ``self.mlog``."""
    self.mlog.take_logarithm(self.radios)


@pipeline_step("sino_normalization", "Normalizing sinograms")
def _normalize_sinos(self, radios=None):
    """Normalize the sinograms, viewed as ``radios.transpose((1, 0, 2))``."""
    if radios is None:
        radios = self.radios
    sinos = radios.transpose((1, 0, 2))
    self.sino_normalization.normalize(sinos)


def _dump_sinogram(self, radios=None):
    """Dump ``radios`` (default ``self.radios``) under the "sinogram" step name, if requested."""
    if radios is None:
        radios = self.radios
    self._dump_data_to_file("sinogram", data=radios)


def _filter(self):
    """Filter ``self.radios_slim`` in place, using the mirror indexes and the rotation axis position."""
    rot_center = self.processing_options["reconstruction"]["rotation_axis_position"]

    self.reconstruction.sino_filter.filter_sino(
        self.radios_slim,
        mirror_indexes=self.mirror_angle_relative_indexes,
        rot_center=rot_center,
        output=self.radios_slim,
    )


def _build_sino(self):
    """Use the current single-slice data (``self.radios_slim``) as the sinogram."""
    # A second, @pipeline_step-decorated definition used to precede this one;
    # it was dead code since this later definition overrode it.
    self.sinos = self.radios_slim


def _reconstruct(self, sinos=None, chunk_info=None, i_slice=0):
    """
    Backproject one sinogram into ``self.recs_stack[i_slice]``.

    Parameters
    ----------
    sinos: array, optional
        Sinogram to backproject. Defaults to ``self.sinos``.
    chunk_info: optional
        Unused here.
    i_slice: int
        Index of the slice in ``self.recs_stack``. On slice 0 (standard
        backprojector), the custom angles are set on the backprojector.

    Raises
    ------
    ValueError
        If the hierarchical backprojector is requested but unavailable.
    """
    if sinos is None:
        sinos = self.sinos

    use_hbp = self.process_config.nabu_config["reconstruction"]["use_hbp"]

    if not use_hbp:
        if i_slice == 0:
            self.reconstruction.set_custom_angles_and_axis_corrections(
                self.my_angles_rad, np.zeros_like(self.my_angles_rad)
            )
        self.reconstruction.backprojection(sinos, output=self.recs[0])
        # copy the result into the host-side stack
        self.recs[0].get(self.recs_stack[i_slice])
    else:
        if self.reconstruction_hbp is None:
            raise ValueError("You requested the hierchical backprojector but the module could not be imported")
        self.reconstruction_hbp.backprojection(sinos, output=self.recs_stack[i_slice])


def _compute_histogram(self, data=None, i_slice=None, num_slices=None):
    """
    Accumulate the histogram of ``data`` (default ``self.recs``).

    Partial histograms are stacked; on the last slice (``i_slice == num_slices - 1``)
    they are merged into ``self.recs_histogram`` and the stack is cleared.
    Without ``num_slices`` no merge is performed.
    """
    if self.histogram is None:
        return
    if data is None:
        data = self.recs
    my_histo = self.histogram.compute_histogram(data.ravel())
    self.histo_stack.append(my_histo)
    if num_slices is not None and i_slice == num_slices - 1:
        self.recs_histogram = self.histogram.merge_histograms(self.histo_stack)
        self.histo_stack.clear()


def _write_data(self, data=None, counter=[0]):
    """
    Write ``data`` (default ``self.recs_stack``) with the configured writer,
    then write the histogram if requested.

    ``counter`` is intentionally a shared mutable default: it counts the calls
    (across all instances) and is used to replace the "config" writer argument.
    """
    if data is None:
        data = self.recs_stack
    my_kw_args = copy.copy(self._writer_exec_kwargs)
    if "config" in my_kw_args:
        self.logger.info("omitting config in writer because of too slow nexus writer ")
        my_kw_args["config"] = {"test": counter[0]}
    counter[0] += 1
    self.writer.write(data, *self._writer_exec_args, **my_kw_args)
    self.logger.info("Wrote %s" % self.writer.get_filename())
    self._write_histogram()


def _write_histogram(self):
    """Save ``self.recs_histogram`` when the "histogram" step is enabled."""
    if "histogram" not in self.processing_steps:
        return
    self.logger.info("Saving histogram")
    self.histogram_writer.write(
        hist_as_2Darray(self.recs_histogram),
        self._get_process_name(kind="histogram"),
        processing_index=self._histogram_processing_index,
        config={
            "file": path.basename(self.writer.get_filename()),
            "bins": self.processing_options["histogram"]["histogram_bins"],
        },
    )


def _dump_data_to_file(self, step_name, data=None):
    """Write ``data`` (default ``self.radios``) with the dump writer registered for ``step_name``, if any."""
    if step_name not in self._data_dump:
        return
    self.logger.info(f"DUMP step_name={step_name}")

    if data is None:
        data = self.radios
    writer = self._data_dump[step_name]
    self.logger.info("Dumping data to %s" % writer.fname)
    writer.write_data(data)
def balance_weights(self):
    """
    Rebalance ``self.radios_weights`` in place with ``rebalance``, then fill the
    trailing radios whose weights are all zero.

    The trailing empty radios (scanning backwards from the last one, index 0
    excluded) receive a copy of the nearest preceding radio with non-zero weights.
    """
    rot_center = self.processing_options["reconstruction"]["rotation_axis_position"]
    self.radios_weights[:] = rebalance(self.radios_weights, self.my_angles_rad, rot_center)

    # Incomplete standard scans (motor errors) may miss part of the 360 degrees.
    # The weights account for it, but the theta+180 padding may then hit empty
    # data, which can upset the ramp filter in half tomography. Fill those
    # radios with the nearest non-empty data available.
    trailing_empty = []
    idx = len(self.radios_weights) - 1
    while idx > 0 and not self.radios_weights[idx].sum():
        trailing_empty.append(idx)
        idx -= 1
    # when trailing_empty is non-empty, idx == trailing_empty[-1] - 1 here
    for target in trailing_empty:
        self.radios[target] = self.radios[idx]
def _post_primary_data_reduction(self, i_slice):
    """
    Load the data of slice ``i_slice`` (all angles, all x) into ``self.radios_slim``.

    Derived classes use this hook to transfer the data to the GPU.
    """
    slice_data = self.radios[:, i_slice, :]
    self.radios_slim[:] = slice_data
def reset_translation_diagnostics_accumulators(self):
    """
    Zero the diagnostic radios and weights, and reset the two entries of
    ``self.diagnostic_proj_angle`` to large sentinel values.
    """
    for accumulator in (self.diagnostic_radios, self.diagnostic_weights):
        accumulator[:] = 0
    big = 2**30
    self.diagnostic_proj_angle[1] = big * 3.14
    self.diagnostic_proj_angle[0] = (big - 1) * 3.14
def process_chunk(self, sub_region=None):
    """Run the whole processing of ``sub_region``, then the finalization step."""
    run_chunk = self._private_process_chunk
    run_chunk(sub_region=sub_region)
    self._process_finalize()
def _private_process_chunk(self, sub_region=None):
    """
    Process one sub-region: read and grid the radios, preprocess them,
    then reconstruct and write every slice.

    Parameters
    ----------
    sub_region: tuple
        (start_x, end_x, start_z, end_z). Mandatory in the helical pipeline.
    """
    assert sub_region is not None, "sub_region argument is mandatory in helical pipeline"

    self.set_subregion(sub_region)
    self.reset_translation_diagnostics_accumulators()

    _, _, subr_start_z, subr_end_z = self.sub_region
    span_v = subr_start_z + self._phase_margin_up, subr_end_z - self._phase_margin_down
    chunk_info = self.span_info.get_chunk_info(span_v)

    self._reset_memory()
    self._init_writer(chunk_info)
    self._configure_data_dumps()

    # Read the angular span in groups of at most reading_granularity projections
    proj_num_start, proj_num_end = chunk_info.angle_index_span
    n_granularity = self.reading_granularity
    pnum_start_list = list(np.arange(proj_num_start, proj_num_end, n_granularity))
    pnum_end_list = pnum_start_list[1:] + [proj_num_end]
    my_first_pnum = proj_num_start

    self.gridded_cumulated_weights[:] = 0
    self.gridded_radios[:] = 0

    for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list):
        start_in_chunk = pnum_start - my_first_pnum
        end_in_chunk = pnum_end - my_first_pnum
        self._read_data_and_apply_flats(
            slice(pnum_start, pnum_end), slice(start_in_chunk, end_in_chunk), chunk_info
        )

    # NOTE(review): entries with zero cumulated weight yield 0/0 here unless
    # the accumulator guarantees non-zero weights -- confirm.
    self.gridded_radios[:] /= self.gridded_cumulated_weights

    if "flatfield" in self._data_dump:
        # Crop the (possibly asymmetric) phase margins before dumping
        up_margin = self._phase_margin_up
        down_margin = self._phase_margin_down
        data_to_dump = self.gridded_radios[:, slice(up_margin or None, -down_margin or None), :]
        self._dump_data_to_file("flatfield", data_to_dump)
        if self.process_config.nabu_config["pipeline"]["skip_after_flatfield_dump"]:
            return

    if "ccd_correction" in self.processing_steps:
        self._ccd_corrections()

    if ("phase" in self.processing_steps) or ("unsharp_mask" in self.processing_steps):
        self._retrieve_phase()
        if "unsharp_mask" in self.processing_steps:
            self._apply_unsharp()
    else:
        self._nophase_put_to_radios(self.radios, self.gridded_radios)

    self.logger.info(" LOG ")
    self._nophase_put_to_radios(self.radios_weights, self.gridded_cumulated_weights)

    self.logger.debug(f"processing steps {self.processing_steps}")

    if "take_log" in self.processing_steps:
        self._take_log()

    self.logger.info(" BALANCE ")
    self.balance_weights()

    num_slices = self.radios.shape[1]

    self.logger.info(" NORMALIZE")
    self._normalize_sinos()
    self._dump_sinogram()

    if "reconstruction" in self.processing_steps:
        for i_slice in range(num_slices):
            self._post_primary_data_reduction(i_slice)  # charge on self.radios_slim
            self._filter()
            self.apply_weights(i_slice)
            self._build_sino()
            self._reconstruct(chunk_info=chunk_info, i_slice=i_slice)
            self._compute_histogram(i_slice=i_slice, num_slices=num_slices)
        self._write_data()
def apply_weights(self, i_slice):
    """
    Multiply ``self.radios_slim`` by the weights of slice ``i_slice``.

    The weights are uploaded through the ``self._d_radios_weights`` buffer
    (via its ``set`` method) in groups of at most
    ``self.num_weight_radios_per_app`` angles. radios_slim is on gpu.
    """
    n_angles = self.radios_slim.shape[0]
    group = self.num_weight_radios_per_app
    for lo in range(0, n_angles, group):
        hi = min(n_angles, lo + group)
        count = hi - lo
        weights_buffer = self._d_radios_weights[:count]
        weights_buffer.set(self.radios_weights[lo:hi, i_slice])
        self.radios_slim[lo:hi] *= weights_buffer
@classmethod
def estimate_required_memory(
    cls, process_config, reading_granularity=None, chunk_size=None, margin_v=0, span_info=None
):
    """
    Estimate the memory (RAM) needed for a reconstruction.

    Parameters
    -----------
    process_config: `ProcessConfig` object
        Data structure with the processing configuration
    reading_granularity: int, optional
        Maximum number of radios read at once. If set to None, 10 is used
        (the default of the pipeline constructor).
    chunk_size: int, optional
        Size of a "radios chunk", i.e "delta z". A radios chunk is a 3D array of shape (n_angles, chunk_size, n_x)
        If set to None, then chunk_size = n_z
    margin_v: int, optional
        Vertical margin added on both sides of the chunk.
    span_info: optional
        Object exposing ``projection_angles_deg``, used to derive the angular step.
        If set to None, the angles (in radians) of
        ``process_config.processing_options["reconstruction"]["angles"]`` are used.

    Returns
    -------
    total_memory_needed
        Estimated memory, in bytes.

    Notes
    -----
    It seems that Cuda does not allow allocating and/or transferring more than 16384 MiB (17.18 GB).
    """
    dataset = process_config.dataset_info
    nabu_config = process_config.nabu_config
    processing_steps = process_config.processing_steps
    Nx, Ny = dataset.radio_dims

    # Honour the documented defaults instead of failing on None arithmetic
    if chunk_size is None:
        chunk_size = Ny
    if reading_granularity is None:
        reading_granularity = 10

    # Angular step of the data, hence number of gridded angles over 360 degrees
    if span_info is not None:
        angles_deg = span_info.projection_angles_deg
    else:
        angles_deg = np.rad2deg(process_config.processing_options["reconstruction"]["angles"])
    my_angle_step = abs(np.diff(angles_deg).mean())
    n_gridded_angles = int(round(360.0 / my_angle_step))

    binning_z = nabu_config["dataset"]["binning_z"]
    padded_height = 2 * margin_v + chunk_size

    total_memory_needed = 0

    # Read data
    # ----------
    # the gridded target (float32)
    total_memory_needed += Nx * padded_height * n_gridded_angles * 4
    # the gridded weights (float32)
    total_memory_needed += Nx * padded_height * n_gridded_angles * 4
    # the read grain, accounted for twice
    # NOTE(review): presumably the reading buffer plus one copy of it -- confirm.
    read_grain_size = (reading_granularity + cls.extra_marge_granularity) * (padded_height + 2) * Nx * 4
    total_memory_needed += read_grain_size
    total_memory_needed += read_grain_size

    # the preprocessed radios, their weight and the buffer used for balancing
    # (total of three buffers of the same size) plus mask plus temporary
    total_memory_needed += 5 * (Nx * chunk_size * n_gridded_angles) * 4

    if "flatfield" in processing_steps:
        # Flat-field is done in-place, but still need to load darks/flats
        n_darks = len(dataset.darks)
        n_flats = len(dataset.flats)
        darks_size = n_darks * Nx * padded_height * 2  # uint16
        flats_size = n_flats * Nx * padded_height * 4  # f32
        total_memory_needed += darks_size + flats_size

    if "ccd_correction" in processing_steps:
        total_memory_needed += Nx * padded_height * 4

    # Phase retrieval
    # ---------------
    if "phase" in processing_steps:
        # Phase retrieval is done image-wise, so near in-place, but needs to
        # allocate some images, fft plans, and so on
        Nx_p = get_next_power(2 * Nx)
        Ny_p = get_next_power(2 * padded_height)
        img_size_real = 2 * 4 * Nx_p * Ny_p
        img_size_cplx = 2 * 8 * ((Nx_p * Ny_p) // 2 + 1)
        total_memory_needed += 2 * img_size_real + 3 * img_size_cplx

    # Reconstruction
    # ---------------
    if "reconstruction" in processing_steps:
        # radios_slim holds one slice at once on the GPU and cannot be reduced
        # further, so it is not part of this (host) estimate.
        rec_config = process_config.processing_options["reconstruction"]
        Nx_rec = rec_config["end_x"] - rec_config["start_x"] + 1
        Ny_rec = rec_config["end_y"] - rec_config["start_y"] + 1
        Nz_rec = chunk_size // binning_z
        # the volume reconstructed for each chunk (float32)
        reconstructed_volume_size = Nx_rec * Ny_rec * Nz_rec * 4
        total_memory_needed += reconstructed_volume_size

    return total_memory_needed
# target_central_slicer, source_central_slicer = overlap_logic( subr_start_z, subr_end_z, dtasrc_start_z, dtasrcs_end_z )
def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z):
    """
    Find the lines that can be transferred from the source range
    ``dtasrc_start_z:dtasrc_end_z`` into the target range ``subr_start_z:subr_end_z``.

    Returns
    -------
    (target_slice, source_slice)
        The overlap expressed relative to the target and to the source start,
        or ``(None, None)`` when the two ranges do not overlap.
    """
    target_height = subr_end_z - subr_start_z
    source_height = dtasrc_end_z - dtasrc_start_z

    # overlap in target-relative coordinates
    t_lo = max(0, dtasrc_start_z - subr_start_z)
    t_hi = min(target_height, dtasrc_end_z - subr_start_z)
    if t_lo >= t_hi:
        return None, None

    # same overlap in source-relative coordinates
    s_lo = max(0, subr_start_z - dtasrc_start_z)
    s_hi = min(source_height, subr_end_z - dtasrc_start_z)
    assert s_lo < s_hi, "Overlap_logic logic error"

    return slice(t_lo, t_hi), slice(s_lo, s_hi)
def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z):
    """
    Find the target lines not covered by the source range, which may be filled
    by extension padding.

    Returns
    -------
    (lower, upper)
        ``lower`` is ``slice(0, n)`` covering target lines before the source
        start, ``upper`` is ``slice(-m, None)`` covering target lines after the
        source end; each is None when there is nothing to pad on that side.
    """
    lower = slice(0, dtasrc_start_z - subr_start_z) if dtasrc_start_z > subr_start_z else None
    upper = slice(dtasrc_end_z - subr_end_z, None) if dtasrc_end_z < subr_end_z else None
    return lower, upper
def _fill_in_chunk_by_shift_crop_data(
    data_target,
    data_read,
    fract_shit,
    my_subr_start_z,
    my_subr_end_z,
    dtasrc_start_z,
    dtasrc_end_z,
    x_shift=0.0,
    extension_padding=True,
):
    """
    Dispatch a freshly read radio to its vertical position in ``data_target``
    by shifting, cropping and, if necessary, padding.

    Parameters
    ----------
    data_target: array
        Destination, written in place; indexed along its first axis by target line.
    data_read: 2D array
        Radio data read over the source range.
    fract_shit: float
        Fractional vertical shift; the data is shifted by ``-fract_shit``.
    my_subr_start_z, my_subr_end_z: int
        Target z range.
    dtasrc_start_z, dtasrc_end_z: int
        Source z range of ``data_read``.
    x_shift: float
        Horizontal shift.
    extension_padding: bool
        If True, missing target lines are filled with the first/last shifted
        line; otherwise with 1.0e-6.

    Notes
    -----
    The last shifted line is dropped, and ``- 1`` is applied to both end
    bounds before computing the overlap and padding ranges.
    """
    # scipy.ndimage.shift (linear interpolation); the deprecated
    # scipy.ndimage.interpolation namespace is no longer used.
    data_read_precisely_shifted = nd.shift(data_read, (-fract_shit, x_shift), order=1, mode="nearest")[:-1]

    target_end_z = my_subr_end_z - 1
    source_end_z = dtasrc_end_z - 1

    target_central_slicer, dtasrc_central_slicer = overlap_logic(
        my_subr_start_z, target_end_z, dtasrc_start_z, source_end_z
    )
    # overlap_logic returns either two slices or (None, None)
    if target_central_slicer is not None and dtasrc_central_slicer is not None:
        data_target[target_central_slicer] = data_read_precisely_shifted[dtasrc_central_slicer]

    target_lower_slicer, target_upper_slicer = padding_logic(
        my_subr_start_z, target_end_z, dtasrc_start_z, source_end_z
    )

    if extension_padding:
        if target_lower_slicer is not None:
            data_target[target_lower_slicer] = data_read_precisely_shifted[0]
        if target_upper_slicer is not None:
            data_target[target_upper_slicer] = data_read_precisely_shifted[-1]
    else:
        if target_lower_slicer is not None:
            data_target[target_lower_slicer] = 1.0e-6
        if target_upper_slicer is not None:
            data_target[target_upper_slicer] = 1.0e-6
def shift(arr, shift, fill_value=0.0):
    """
    Shift a 2D array horizontally (along its last axis) by a possibly
    fractional amount.

    The result is the linear combination of the two neighbouring integer
    shifts, so unlike ``scipy.ndimage.shift`` the tails are not cut abruptly.
    Columns coming from outside the array receive ``fill_value``.
    """
    result = np.zeros_like(arr)
    lower = int(math.floor(shift))
    upper_weight = shift - lower
    for offset, weight in ((lower, 1 - upper_weight), (lower + 1, upper_weight)):
        if offset == 0:
            result[:] += arr * weight
        elif offset > 0:
            result[:, :offset] += fill_value * weight
            result[:, offset:] += arr[:, :-offset] * weight
        else:
            result[:, offset:] += fill_value * weight
            result[:, :offset] += arr[:, -offset:] * weight
    return result
def rebalance(radios_weights, angles, ax_pos):
    """
    Rebalance the weights within groups of equivalent data pixels, i.e. pixels
    seen at angles differing by a multiple of 180 degrees.

    Parameters
    ----------
    radios_weights: array
        Weights, first axis indexed like ``angles``; the last axis is the
        horizontal (x) axis.
    angles: 1D array
        Projection angles in radians, assumed sorted increasingly
        (``np.searchsorted`` is used on them).
    ax_pos: float
        Rotation axis position, in pixels.

    Returns
    -------
    balanced: array
        ``radios_weights`` divided by the sum of the equivalent weights
        (odd half-turns are mirrored and shifted around the axis). Entries
        whose input weight is zero are set to zero.
    """
    balanced = np.zeros_like(radios_weights)
    # Number of half turns spanned by the angles (parenthesis fixed: the
    # ceiling must be applied after the division by pi).
    n_span = int(math.ceil((angles[-1] - angles[0]) / np.pi))
    center = (radios_weights.shape[-1] - 1) / 2
    n_angles = angles.shape[0]

    for i in range(balanced.shape[0]):
        w_res = balanced[i]
        angle = angles[i]
        for i_half_turn in range(-n_span - 1, n_span + 2):
            if i_half_turn == 0:
                w_res[:] += radios_weights[i]
                continue

            shifted_angle = angle + i_half_turn * np.pi
            insertion_index = np.searchsorted(angles, shifted_angle)

            if insertion_index == 0:
                # before the first angle: no data
                continue
            if insertion_index == n_angles:
                # after the last angle: only accepted up to 2 pi
                if shifted_angle > 2 * np.pi:
                    continue
                myimage = radios_weights[-1]
            else:
                # linear interpolation between the two bracketing angles
                partition = (shifted_angle - angles[insertion_index - 1]) / (
                    angles[insertion_index] - angles[insertion_index - 1]
                )
                myimage = (1.0 - partition) * radios_weights[insertion_index - 1] + partition * radios_weights[
                    insertion_index
                ]

            if i_half_turn % 2 == 0:
                w_res[:] += myimage
            else:
                # odd half turns see the detector mirrored around the rotation axis
                myimage = np.fliplr(myimage)
                w_res[:] += shift(myimage, (2 * ax_pos - 2 * center))

    mask = np.equal(0, radios_weights)
    # Zero sums only occur where the weights themselves are zero (masked below);
    # silence the resulting 0/0 warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        balanced[:] = radios_weights / balanced
    balanced[mask] = 0
    return balanced