# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Author: Stephan J. Ginthoer

# How to use this script:
# Run as: python reconstruct.py path_to_data_folder
# Python version >= 3.5

# path_to_data_folder must contain the following files:
# proj_<length>.raw
# Usually these files are created with the preprocess.py script
# <length> is the number of datapoints per projection
# The raw files are flat dumps of 2D float64 numpy arrays
# Each row contains one projection (<length> float64 values)
# recorded at certain phi, theta angles
# Projections are blocked by same angle of phi
# Both angles are in ascending order
# Example:
# Phi Theta
# 1 degrees 1 degrees
# 1 degrees 2 degrees
# ...................
# 1 degrees N degrees
# 2 degrees 1 degrees
# ...................
# 2 degrees N degrees
# ...................
# M degrees N degrees
# Final 3D image will be exported as raw int16 values to: path_to_data_folder/image_3d.raw

import numpy as np
import scipy as sp
import skimage.transform as t
import pathlib as pl
import sys

#---------Settings for reconstruction-------------------------

# This variable controls the shape of the reconstruction pipeline
# The keys in the dictionary are the projection (window) lengths to use
# The corresponding values list the relaxation factor of each iteration;
# the list length is how often the reconstruction is repeated at that length
reconstruction_config = {
    128: [0.05] * 9,   # nine SART passes on the 128-sample projections
    1024: [0.05],      # then one SART pass on the 1024-sample projections
}

# The angles (degrees, 30 steps of 6 covering [0, 180)) the projections were recorded at
theta = np.linspace(0, 180, num=30, endpoint=False)
phi = np.linspace(0, 180, num=30, endpoint=False)

#---------------End of Settings-------------------------------

# Data folder given on the command line: the proj_<length>.raw inputs are read
# from it and image_3d.raw is written into it. It is parsed at import time, so
# report a usage message instead of failing with a bare IndexError when the
# argument is missing.
if len(sys.argv) < 2:
    sys.exit('Usage: python reconstruct.py path_to_data_folder')
path = pl.Path(sys.argv[1])

# Upscales (interpolates) a 2D image (array) by the factor `by` (default: double the size)
# Input must be square
def upscale2d(image_2d, by=2):
    """Upscale a square 2D image by cubic spline interpolation.

    Pixels are treated as samples at their centres: old pixel ``k`` sits at
    ``k + 0.5`` and the new grid spans the same extent with
    ``round(by * old_length)`` pixels. New pixel centres that fall outside
    the span of the old centres (the outermost row(s)/column(s)) are set to
    0 rather than extrapolated.

    Args:
        image_2d: square 2D array; the cubic spline needs at least 4
            samples per axis.
        by: scale factor; does not have to be an integer (e.g. 1.5).

    Returns:
        Array of shape ``(new_length, new_length)`` with
        ``new_length = round(by * image_2d.shape[0])``.
    """
    # scipy.interpolate.interp2d (used previously) was deprecated in SciPy
    # 1.10 and removed in 1.14; RectBivariateSpline with s=0 fits the same
    # interpolating cubic spline on a regular grid. Import the submodule
    # explicitly instead of relying on `import scipy` exposing it.
    from scipy.interpolate import RectBivariateSpline

    old_length = image_2d.shape[0]
    new_length = int(round(by * old_length))

    # Pixel-centre coordinates of both grids, expressed in old-pixel units.
    # Multiplying before dividing keeps values like 0.5 exact.
    old_x = np.arange(old_length) + 0.5
    new_x = (np.arange(new_length) + 0.5) * old_length / new_length

    spline = RectBivariateSpline(old_x, old_x, image_2d, kx=3, ky=3, s=0)
    new_image = spline(new_x, new_x)

    # Reproduce the former fill_value=0: no extrapolation past the samples.
    outside = (new_x < old_x[0]) | (new_x > old_x[-1])
    new_image[outside, :] = 0
    new_image[:, outside] = 0

    return new_image


def reconstruct(projections_all_res, config, thetas):
    """Reconstruct one 2D image with coarse-to-fine SART.

    The projection lengths in ``config`` are processed in ascending order.
    For each length the current estimate is upscaled with ``upscale2d`` to
    ``length x length`` when its size differs, and is then refined by one
    ``iradon_sart`` pass per relaxation factor in ``config[length]``. The
    first estimate is all zeros.

    Args:
        projections_all_res: dict mapping projection length to a 2D array
            holding one projection per row (one row per entry of ``thetas``).
        config: dict mapping projection length to a list of SART relaxation
            factors; the list length is the number of passes at that length.
        thetas: projection angles, passed as ``theta`` to ``iradon_sart``.

    Returns:
        The final estimate, a square array whose side is the largest
        projection length in ``config``.

    Raises:
        ValueError: if ``config`` is empty.
    """
    projection_lengths = sorted(config)
    if not projection_lengths:
        raise ValueError('config must contain at least one projection length')

    first_length = projection_lengths[0]
    image_2d = np.zeros((first_length, first_length))

    for projection_length in projection_lengths:
        last_length = image_2d.shape[0]
        if last_length != projection_length:
            image_2d = upscale2d(image_2d, projection_length / last_length)

        projections_one_res = projections_all_res[projection_length]
        # iradon_sart expects one projection per column.
        sinogram = projections_one_res.transpose()
        # Zero shift for every projection (one entry per row).
        shifts = np.zeros_like(projections_one_res[:, 0])

        for relax_rate in config[projection_length]:
            # iradon_sart returns the refined image; it must replace the
            # estimate so the next pass (and the caller) sees the update.
            image_2d = t.iradon_sart(sinogram, theta=thetas, image=image_2d,
                                     relaxation=relax_rate,
                                     projection_shifts=shifts)

    return image_2d


def groupProjections(all_projections, group_index, group_size=None):
    """Select the block of projections belonging to one phi angle.

    The projection arrays are blocked by phi: each consecutive run of
    ``group_size`` rows shares one phi angle. This picks block
    ``group_index`` from every resolution.

    Args:
        all_projections: dict mapping projection length to a 2D array with
            one projection per row.
        group_index: index of the block (phi angle) to extract.
        group_size: rows per block; defaults to ``len(theta)``, the number
            of theta angles recorded per phi angle.

    Returns:
        Dict with the same keys, each mapped to rows
        ``[group_index * group_size, (group_index + 1) * group_size)``.
    """
    if group_size is None:
        group_size = len(theta)

    start = group_index * group_size
    stop = start + group_size
    return {length: all_projections[length][start:stop]
            for length in all_projections}


def reshape3dProjections(projections):
    """Turn a ``(1, width, count)`` slab into a ``(count, width)`` array.

    Only ``shape[1]`` (width) and ``shape[2]`` (count) define the target,
    so the input must hold exactly ``width * count`` elements. In this
    script it is one row slice of the stacked 2D images, taken with
    ``np.split(..., axis=0)``. Each row of the result is one projection of
    length ``width``.
    """
    width, count = projections.shape[1], projections.shape[2]
    as_2d = np.reshape(projections, (width, count))
    return np.transpose(as_2d)


def normalizeToInt16(arr, min_, max_):
    """Linearly map ``arr`` from ``[min_, max_]`` onto ``[0, 32000]`` as int16.

    ``min_`` is subtracted and the result is scaled by
    ``32000 / (max_ - min_)``. The int16 cast truncates toward zero (no
    rounding). 32000 keeps in-range results below the int16 maximum of
    32767; values outside ``[min_, max_]`` are not clipped.

    If ``max_ == min_`` the range is degenerate. Instead of dividing by
    zero (inf/NaN, whose int16 cast is undefined) an all-zero array is
    returned.

    Returns:
        int16 array with the same shape as ``arr``.
    """
    new_arr = np.zeros_like(arr, dtype=np.int16)

    value_range = max_ - min_
    if value_range == 0:
        return new_arr

    factor = 32000.0 / value_range
    # `[...]` also works for 0-d arrays, unlike `[:]`.
    new_arr[...] = (arr - min_) * factor
    return new_arr


def helper(projections):
    """Reconstruct one phi group with the module-level settings.

    Single-argument adapter used with ``map``: forwards ``projections`` to
    ``reconstruct`` together with ``reconstruction_config`` and the
    ``theta`` angles.
    """
    return reconstruct(projections_all_res=projections,
                       config=reconstruction_config,
                       thetas=theta)


def helper3d(projections):
    """Reconstruct one slice from projections recorded over the phi angles.

    ``projections`` holds one projection per row. It is transposed into
    the column-per-projection layout ``iradon_sart`` expects. The slice is
    then refined by two SART passes with the default relaxation, the
    second pass starting from the result of the first.
    """
    sinogram = projections.transpose()
    image = None  # first pass: iradon_sart starts from its default estimate
    for _ in range(2):
        image = t.iradon_sart(sinogram, theta=phi, image=image)
    return image


def runPipeline():
    """Run the full two-stage reconstruction and write ``image_3d.raw``.

    Stage 1: for every phi angle, reconstruct a 2D image from the
    projections recorded at all theta angles (``groupProjections`` +
    ``helper``).

    Stage 2: stack those images along a new last axis. For each index
    along the first axis, the resulting ``(width, phi)`` slice is turned
    into projections over phi (``reshape3dProjections``) and reconstructed
    with ``helper3d``.

    The slices are normalized to one shared int16 range and dumped with
    ``ndarray.tofile`` to ``path / 'image_3d.raw'``.

    Raises:
        ValueError: if a ``proj_<length>.raw`` file does not hold a whole
            number of projections, or holds a number other than
            ``len(phi) * len(theta)``.
    """
    expected_count = len(phi) * len(theta)

    all_projections = dict()
    for projection_length in sorted(reconstruction_config):
        file_path = path / ('proj_' + str(projection_length) + '.raw')
        data = np.fromfile(file_path, dtype=np.float64)

        if data.size % projection_length:
            raise ValueError(
                '{} holds {} values, which is not a multiple of the '
                'projection length {}'.format(file_path, data.size,
                                              projection_length))
        projections = data.reshape((-1, projection_length))

        # Groups are sliced by position, so a wrong count would silently
        # drop data or produce short groups; fail early and clearly instead.
        if projections.shape[0] != expected_count:
            raise ValueError(
                '{} holds {} projections, expected {} ({} phi x {} theta '
                'angles)'.format(file_path, projections.shape[0],
                                 expected_count, len(phi), len(theta)))

        all_projections[projection_length] = projections

    # Stage 1: one 2D reconstruction per phi angle.
    grouped_projections = [groupProjections(all_projections, phi_index)
                           for phi_index in range(len(phi))]
    images_2d = [helper(group) for group in grouped_projections]
    all_images_2d = np.stack(images_2d, axis=2)

    # Stage 2: one reconstruction per first-axis index, over the phi angles.
    image_width = all_images_2d.shape[0]
    row_slabs = np.split(all_images_2d, range(1, image_width), axis=0)
    image_3d_list = [helper3d(reshape3dProjections(slab))
                     for slab in row_slabs]

    # A single global range keeps intensities comparable across slices.
    max_ = max(z_slice.max() for z_slice in image_3d_list)
    min_ = min(z_slice.min() for z_slice in image_3d_list)

    image_3d = np.stack([normalizeToInt16(z_slice, min_, max_)
                         for z_slice in image_3d_list], axis=2)
    image_3d.tofile(path / 'image_3d.raw')


def main():
    """Script entry point: run the reconstruction on the data folder in ``path``."""
    runPipeline()


# Only start the pipeline when this file is executed as a script.
if __name__ == '__main__':
    main()
