MRI_Filtering_Revised.ipynb
#%% 
import matplotlib.pyplot as plt
import scipy.ndimage
import numpy as np
import cv2
import os
from scipy.stats import t
from statistics import NormalDist

#%% 
def register_images(reference_gray, transformed_gray):
    # blur the original images
    reference_gray_blurred = cv2.bilateralFilter(reference_gray, d=11, # diameter of filter 11
                                                 sigmaColor=5,
                                                 sigmaSpace=5)
    transformed_gray_blurred = cv2.bilateralFilter(transformed_gray, d=11, # diameter of filter 11
                                                   sigmaColor=5,
                                                   sigmaSpace=5)

    # Calculate the phase correlation between the images
    # graphing(transformed_gray_blurred)
    result = cv2.phaseCorrelate(reference_gray_blurred[:128, 128:384], transformed_gray_blurred[:128, 128:384])

    # Extract the translation parameters (shift)
    dx, dy = result[0]

    if np.sqrt(dx**2 + dy**2) < 50:
        # Warp the transformed image to align with the reference image
        rows, cols = reference_gray.shape
        M = np.float32([[1, 0, dx], [0, 1, dy]])
        aligned_image = cv2.warpAffine(transformed_gray, M, (cols, rows))

        return aligned_image
    else:
        return transformed_gray

def graphing(image):
    plt.figure(figsize=(5, 5), dpi=100)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.show()


def graph_set(columns, ims, titles):
    fig = plt.figure(figsize=(12, 12))

    num_im = np.shape(ims)[0]
    rows = int(np.ceil(num_im / columns))

    for i in range(num_im):
        fig.add_subplot(rows, columns, i + 1)
        plt.imshow(ims[i], cmap='Greys_r')
        plt.title(titles[i])
        plt.axis('off')
    plt.show()


def up_resolution(image, sample_factor):
    # cubic-spline interpolation (order=3)
    new_image = scipy.ndimage.zoom(image, sample_factor, order=3)
    return np.uint8(new_image)
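#%%
# A minimal sanity check (synthetic data, not part of the MRI analysis) of how
# cv2.phaseCorrelate, used inside register_images, reports a translation:
# circularly shift a random float32 image by a known amount and inspect the
# recovered (dx, dy).  The array size and shift values here are arbitrary.
rng = np.random.default_rng(0)
demo = rng.random((128, 256)).astype('float32')
demo_shifted = np.roll(demo, shift=(3, 5), axis=(0, 1))  # shift by 3 rows, 5 columns
(dx, dy), response = cv2.phaseCorrelate(demo, demo_shifted)
print('recovered shift (dx, dy):', dx, dy, 'peak response:', response)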
#%% 
# read in some images for testing
filename = r'file directory to Raw_BMP'

# names of the image-set folders
image_sets = os.listdir(filename)

# read every image of every set in as grayscale
dataset = []
for image_set in image_sets:
    im_filename = os.path.join(filename, image_set)
    for image_name in os.listdir(im_filename):
        dataset.append(cv2.imread(os.path.join(im_filename, image_name), 0))


# reshape the flat list into (image set, image number, rows, columns)
dataset = np.reshape(dataset, (13, 30, 128, 256))
print(np.shape(dataset))


# create dataset of upsampled images
upsampled_dataset = np.zeros((np.shape(dataset)[0], np.shape(dataset)[1], np.shape(dataset)[2]*2, np.shape(dataset)[3]*2), dtype='float32')

for folder_set in range(np.shape(dataset)[0]):
    for image in range(np.shape(dataset)[1]):
        upsampled_dataset[folder_set, image, :, :] = up_resolution(dataset[folder_set, image, :, :], 2)

print(np.shape(upsampled_dataset))
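#%%
# Quick shape check (synthetic example, not part of the analysis): the cubic
# upsampling in up_resolution doubles each spatial dimension for sample_factor=2.
demo_small = np.arange(16, dtype='float32').reshape(4, 4)
print(np.shape(up_resolution(demo_small, 2)))  # expected: (8, 8)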
#%% 
# register images using a portion of the images
aligned_images = upsampled_dataset.copy()

# graphing(aligned_images[6, 0, :128, 128:384])

# image set 6 serves as the reference for every image index
for i in range(np.shape(aligned_images)[1]-1):
    for j in range(np.shape(aligned_images)[0]):
        ref_im = aligned_images[6, i, :, :]
        move_im = aligned_images[j, i, :, :]

        aligned_images[j, i, :, :] = register_images(ref_im, move_im)


# graph_set(5, aligned_images[:, 2, :, 128:384], image_sets)

#%% 
# compare to nonlocal means filter - this is the final set of images

def enhance_image(temp_dataset, image_num):
    # correction for two different sets of averages
    if image_num > 4:
        filter_strength = 10
    else:
        filter_strength = 5

    # temp_dataset is a stack of frames: (frame index, 1st image dimension, 2nd image dimension)
    nl_image = cv2.fastNlMeansDenoisingMulti(temp_dataset,
                                             imgToDenoiseIndex=image_num,
                                             temporalWindowSize=3, # target frame plus one neighbour on each side; must be odd
                                             h=filter_strength, # filter strength
                                             templateWindowSize=7,
                                             searchWindowSize=21)
    # enhance the edges
    bd_im = cv2.bilateralFilter(nl_image,
                                d=5, # diameter of filter
                                sigmaColor=15,
                                sigmaSpace=15)

    # contrast enhancement
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(10,10))
    intense_image = clahe.apply(bd_im)

    return intense_image
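#%%
# A minimal sketch (synthetic uint8 frames, not the MRI data) of what
# enhance_image expects: a stack of same-sized 8-bit frames indexed along axis 0.
# With temporalWindowSize=3 the denoised frame needs a neighbour on each side,
# so endpoint frames cannot be passed as image_num.
rng = np.random.default_rng(1)
demo_stack = np.clip(rng.normal(128, 20, size=(5, 64, 64)), 0, 255).astype('uint8')
demo_out = enhance_image(demo_stack, image_num=2)
print(demo_out.shape, demo_out.dtype)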

#%% 
# Denoise images
processed_images = np.zeros((9, 256, 256))
for i in range(11):
    if i == 0 or i == 10:
        continue

    # nonlocal means filtering
    new_im = enhance_image(temp_dataset=aligned_images[:, 2, :, 128:384].astype('uint8'),
                           image_num=i)

    processed_images[i-1, :, :] = new_im
#%% 
# find segmentation of the spinal cord from first image
segmented_image = processed_images[0, :, :].copy()
graphing(segmented_image)

# threshold image
ret, th1 = cv2.threshold(segmented_image[25:125, 80:180], 76, 255, cv2.THRESH_BINARY)

# find contours
graphing(th1)
contours, hierarchy = cv2.findContours(th1.astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

# Draw and fill the spinal-cord contour (hard-coded contour index 16)
sp_mask = np.zeros((th1.shape[0], th1.shape[1]), dtype=np.uint8)
cv2.drawContours(sp_mask, contours, 16, 255, 1, cv2.LINE_8, hierarchy, 0)
sp_mask_outline = sp_mask.copy()
cv2.fillConvexPoly(sp_mask, contours[16], 255)
graphing(sp_mask)
graphing(sp_mask_outline)
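
# Optional inspection step (an assumption about how the hard-coded contour index
# above was chosen): list the contour areas so the spinal-cord outline can be
# identified by index.
for idx, c in enumerate(contours):
    print(idx, cv2.contourArea(c))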

#%% 
# calculate unique regions in the spinal cord

def welch_critical_t(df, alpha):
    # one-sided critical value of Student's t for the given degrees of freedom
    return t.ppf(1 - alpha, df)
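
# A hedged sketch (not used by the analysis below) of how welch_critical_t could
# be applied: Welch's t statistic for two ROI samples with the Welch-Satterthwaite
# degrees of freedom.  welch_t_test is a hypothetical helper for illustration only.
def welch_t_test(sample_a, sample_b, alpha=0.05):
    a = np.asarray(sample_a, dtype=float)
    b = np.asarray(sample_b, dtype=float)
    va = a.var(ddof=1) / len(a)  # squared standard error of the mean, sample a
    vb = b.var(ddof=1) / len(b)  # squared standard error of the mean, sample b
    t_stat = (a.mean() - b.mean()) / np.sqrt(va + vb)
    df = (va + vb)**2 / (va**2 / (len(a) - 1) + vb**2 / (len(b) - 1))
    return t_stat, welch_critical_t(df, alpha)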


num_images = 4

averaged_compressed_intensity = np.mean(processed_images[:num_images, :, :], axis=0)
averaged_decompressed_intensity = np.mean(processed_images[-num_images:, :, :], axis=0)

signal = averaged_compressed_intensity - averaged_decompressed_intensity

plt.imshow(signal[25:125, 80:180], cmap='Greys', vmin=0, vmax=np.max(signal[25:125, 80:180]))
plt.show()


# zero out differences at or below two standard deviations of the fitted intensity distribution
vals = signal[25:125, 80:180].flatten()
norm = NormalDist.from_samples(vals)
signal[signal <= 2.0*norm.stdev] = 0

plt.hist(vals, 255)
plt.show()

# mask the images
result = cv2.bitwise_and(signal[25:125, 80:180].astype('uint8'), sp_mask)
result[result > 1] = 255 # make the patches uniform in color

# find each patch and remove any smaller than 10 pixels
contours, hierarchy = cv2.findContours(result, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
# Draw and fill a mask for each remaining patch
masks = []
for i in range(len(contours)):
    mask = np.zeros((th1.shape[0], th1.shape[1]), dtype=np.uint8)
    cv2.drawContours(mask, contours, i, 255, 1, cv2.LINE_8, hierarchy, 0)
    cv2.fillConvexPoly(mask, contours[i], 255)
    if cv2.countNonZero(mask) < 10:  # drop patches smaller than 10 pixels
        continue
    masks.append(mask)

# for mask in masks:
#     graphing(mask)

result = cv2.addWeighted(result, 1, sp_mask_outline, 0.2, 0)

plt.imshow(result, cmap='Greys_r', vmin=0, vmax=np.max(signal[25:125, 80:180]))
plt.show()

# cv2.imwrite('Masks.tiff', signal[25:125, 80:180])
#%% 
def roi_points(oim, ellipse_center, ellipse_axes, ellipse_angle):
    temp_im = oim.copy()

    # Sample the ellipse boundary as a polygon (one vertex every 10 degrees)
    ellipse_points = cv2.ellipse2Poly(ellipse_center, ellipse_axes, ellipse_angle, 0, 360, 10)

    # Convert points to integer values
    ellipse_points = ellipse_points.astype(np.int32)

    # Create masks for the shapes
    ellipse_mask = np.zeros_like(temp_im)

    # fill polygons to enclose shape
    cv2.fillPoly(ellipse_mask, [ellipse_points], 255)

    # Find average intensity of enclosed ellipse
    locations = np.fliplr(cv2.findNonZero((ellipse_mask>0).astype(np.uint8)).squeeze())
    ellipse_intensities = []
    for location in locations:
        ellipse_intensities.append(temp_im[location[0], location[1]])

    ave_intensity = np.mean(ellipse_intensities)
    std_intensity = np.std(ellipse_intensities)
    std_err_mean = std_intensity / np.sqrt(len(locations))

    # Draw the ellipse outline on a copy of the image for display
    cv2.drawContours(temp_im, [ellipse_points], -1, 255, 1)

    # cv2.imwrite('Zoomed_Ellipse_%d.tiff' %i, temp_im[25:125, 80:180])
    return ave_intensity, std_err_mean, temp_im

def roi_masks(oim, func_mask):
    masked_result = cv2.bitwise_and(oim[25:125, 80:180].astype('uint8'), func_mask)

    locations = np.nonzero(masked_result != 0)

    roi_intensities = (masked_result[locations[0], locations[1]])

    ave_intensity = np.mean(roi_intensities)
    err = np.std(roi_intensities) / np.sqrt(len(roi_intensities))  # standard error of the mean

    return ave_intensity, err

#%% 
portion_images = processed_images.copy()


intensities = []
errors = []


for mask in masks:
    temp_intensities = []
    temp_errors = []
    for im in portion_images:
        intensity, error = roi_masks(im, mask)

        temp_intensities.append(intensity)
        temp_errors.append(error)

    intensities.append(temp_intensities)
    errors.append(temp_errors)

temp_intensities = []
temp_errors = []
for i in range(len(portion_images)):
    # elliptical control region for the distortion analysis (plotted in 'orange')
    i6, err6, modified_image = roi_points(portion_images[i, :, :],
                                          ellipse_center=(129, 67),
                                          ellipse_axes=(4, 4),
                                          ellipse_angle=0)

    temp_intensities.append(i6)
    temp_errors.append(err6)

    # if i == 0 or i == 8:
    #     cv2.imwrite('Zoomed_Ellipse_Regions_%d.bmp' %i, modified_image[25:125, 80:180])

    # plt.imshow(modified_image[25:125, 80:180], cmap='Greys', vmin=0, vmax=255)
    # plt.show()

# add control region
intensities.append(temp_intensities)
errors.append(temp_errors)

# convert to numpy arrays
intensities = np.array(intensities)
errors = np.array(errors)

# normalize each region by the average of its first four data points
average_intensities = []
SEM_intensities = []
for i in intensities:
    temp = i[:4]
    average_intensities.append(np.mean(temp))
    SEM_intensities.append(np.std(temp) / np.sqrt(len(temp)))

normalized_intensities = []
normalized_errors = []
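# Error propagation for R = I / A, where A is the mean of the first four points:
# sigma_R^2 = (sigma_I / A)^2 + (I / A^2)^2 * sigma_A^2; the expression inside
# the loop below computes exactly this.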
for i in range(len(intensities)):
    normalized_intensities.append(intensities[i, :] / average_intensities[i])
    normalized_errors.append(np.sqrt((errors[i, :] / average_intensities[i])**2
                                     + (intensities[i, :] / average_intensities[i]**2)**2 * SEM_intensities[i]**2))

normalized_intensities = np.array(normalized_intensities)
normalized_errors = np.array(normalized_errors)


# # Graph the intensities
# fig, ax1 = plt.subplots()
# x = np.linspace(1, len(intensities)-1, 10)
#
# ax1.set_xlabel('Image Number')
# ax1.set_ylabel('Relative Intensity')
# # ax1.set_yticks([40, 60, 80, 100, 120, 140, 160, 180, 200])
# colors = ['red', 'blue', 'yellow', 'black', 'green', 'orange', 'black']
# for i in range(7):
#     # print(i)
#     # print(intensities[i])
#     # print(errors[i])
#     # ax1.errorbar(x[1:10], intensities[i, :], yerr=errors[i, :], color = colors[i])
#     ax1.plot(x[1:10], intensities[i, :] / intensities[i, 0], color = colors[i])
#     ax1.set_ylim([0, 1.6])
#     ax1.fill_between(x[1:10], intensities[i, :] + errors[i, :], intensities[i, :] - errors[i, :], alpha=0.1, color = colors[i])
# plt.show()


# Graph the normalized intensities
fig, ax1 = plt.subplots()
x = np.arange(1, np.shape(normalized_intensities)[1] + 1)  # image numbers 1-9

ax1.set_xlabel('Image Number')
ax1.set_ylabel('Relative Intensity')
colors = ['red', 'blue', 'yellow', 'black', 'green', 'orange', 'black']
for i in range(7):
    ax1.plot(x, normalized_intensities[i, :], color=colors[i])
    ax1.fill_between(x, normalized_intensities[i, :] + normalized_errors[i, :],
                     normalized_intensities[i, :] - normalized_errors[i, :],
                     alpha=0.1, color=colors[i])
ax1.set_ylim([0, 1.6])
plt.show()

for i in range(7):
    print(f'Intensity and error of region {i}')

    for j in range(9):
        print(normalized_intensities[i, j], end="\t")
        print(normalized_errors[i, j])

# for mask in masks:
#     graphing(mask)
#%% 
# draw the final mask with the center control region
ellipse_points = cv2.ellipse2Poly((129, 67), (4, 4), 0, 0, 360, 10)
ellipse_points = ellipse_points.astype(np.int32)
ellipse_mask = np.zeros_like(segmented_image)
cv2.fillPoly(ellipse_mask, [ellipse_points], np.max(signal[25:125, 80:180]))
graphing(ellipse_mask[25:125, 80:180])
result = cv2.add(result, ellipse_mask[25:125, 80:180].astype('uint8'))  # saturating add avoids uint8 wrap-around
# graphing(result)
plt.imshow(result, cmap='Greys_r', vmin=0, vmax=np.max(signal[25:125, 80:180]))
plt.imsave('Final_Mask.tiff', result, cmap='Greys', vmin=0, vmax=np.max(signal[25:125, 80:180]))
plt.show()
#%% 
# draw the rois over the images
full_mask = np.zeros(np.shape(signal[25:125, 80:180]))

full_mask += ellipse_mask[25:125, 80:180]

for mask in masks:
    full_mask += mask

# find the contours
contours, hierarchy = cv2.findContours(full_mask.astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

first = portion_images[0, 25:125, 80:180].copy()
last = portion_images[-1, 25:125, 80:180].copy()

for i in range(len(contours)):
    cv2.drawContours(first, contours, i, 255, 1, cv2.LINE_8, hierarchy, 0)
    cv2.drawContours(last, contours, i, 255, 1, cv2.LINE_8, hierarchy, 0)

# plt.imshow(first, cmap='Greys', vmin=0, vmax=255)
# plt.imsave('First_Frame.tiff', first, cmap='Greys_r', vmin=0, vmax=255)
# plt.show()
# plt.imshow(last, cmap='Greys', vmin=0, vmax=255)
# plt.imsave('Last_Frame.tiff', last, cmap='Greys_r', vmin=0, vmax=255)
# plt.show()
# #
graphing(first)
graphing(last)