# -*- coding: utf-8 -*-

import nrrd
import slicerio
import unittest


class TestSegmentationRoundtrip(unittest.TestCase):
    """
    Test reading, segment extraction, creation and writing/reading roundtrip
    of segmentations using slicerio.
    """

    def test_segmentation_read(self):
        """Test segmentation reading (metadata only, voxels are skipped)."""

        input_segmentation_filepath = slicerio.get_testdata_file('Segmentation.seg.nrrd')

        segmentation = slicerio.read_segmentation(input_segmentation_filepath, skip_voxels=True)

        self.assertEqual(len(segmentation["segments"]), 7)

        # Comparing the complete list checks both the names and their order
        segment_names = slicerio.segment_names(segmentation)
        self.assertEqual(list(segment_names), [
            'ribs',
            'cervical vertebral column',
            'thoracic vertebral column',
            'lumbar vertebral column',
            'right lung',
            'left lung',
            'tissue',
        ])

        segment = slicerio.segment_from_name(segmentation, 'right lung')
        self.assertEqual(segment['id'], 'Segment_5')
        self.assertEqual(segment['name'], 'right lung')
        self.assertEqual(segment['nameAutoGenerated'], True)
        self.assertEqual(segment['color'], [0.0862745, 0.772549, 0.278431])
        self.assertEqual(segment['colorAutoGenerated'], False)
        self.assertEqual(segment['labelValue'], 5)
        self.assertEqual(segment['layer'], 0)
        self.assertEqual(segment['extent'], [0, 124, 0, 127, 0, 33])
        self.assertEqual(segment['status'], 'inprogress')

        terminology = segment["terminology"]
        self.assertEqual(terminology['contextName'], 'Segmentation category and type - 3D Slicer General Anatomy list')
        self.assertEqual(terminology['category'], ['SCT', '123037004', 'Anatomical Structure'])
        self.assertEqual(terminology['type'], ['SCT', '39607008', 'Lung'])
        self.assertEqual(terminology['typeModifier'], ['SCT', '24028007', 'Right'])
        self.assertEqual(terminology['anatomicContextName'], 'Anatomic codes - DICOM master list')
        self.assertNotIn('anatomicRegion', terminology)
        self.assertNotIn('anatomicRegionModifier', terminology)

    def test_extract_segments(self):
        """Test extracting segments by name and by terminology, with remapped label values."""

        # Number of voxels expected for each label value in the extracted segmentation.
        # SegmentationOverlapping.seg.nrrd contains an additional segment (a sphere) overlapping
        # with ribs and right lung, but the sphere is not requested, so it should not affect
        # the extracted output: the same counts are expected for both input files.
        expected_voxel_counts = {
            0: 514119,  # background
            1: 8487,  # ribs
            2: 0,  # unused label
            3: 34450,  # right lung
            4: 0,  # unused label
        }

        input_segmentation_filenames = ['Segmentation.seg.nrrd', 'SegmentationOverlapping.seg.nrrd']
        for input_segmentation_filename in input_segmentation_filenames:
            # subTest reports a failure for each input file separately instead of
            # stopping at the first failing file.
            with self.subTest(input_segmentation_filename=input_segmentation_filename):

                input_segmentation_filepath = slicerio.get_testdata_file(input_segmentation_filename)
                segmentation = slicerio.read_segmentation(input_segmentation_filepath)

                # Extract segments by name
                extracted_segmentation_by_name = slicerio.extract_segments(
                    segmentation, [('ribs', 1), ('right lung', 3)]
                )

                # Verify pixel type of new segmentation
                self.assertEqual(extracted_segmentation_by_name["voxels"].dtype, segmentation["voxels"].dtype)

                # Verify that the extracted segmentation contains the requested label values
                self._assert_label_voxel_counts(extracted_segmentation_by_name["voxels"], expected_voxel_counts)

                # Extract segments by terminology
                extracted_segmentation_by_terminology = slicerio.extract_segments(
                    segmentation, [
                        # Note: the code meaning "Ribs" intentionally differs from the segment name ("ribs")
                        # to test that matching is based on the terminology code value and not on the text.
                        ({"category": ["SCT", "123037004", "Anatomical Structure"], "type": ["SCT", "113197003", "Ribs"]}, 1),
                        ({"category": ["SCT", "123037004", "Anatomical Structure"], "type": ["SCT", "39607008", "Lung"], "typeModifier": ["SCT", "24028007", "Right"]}, 3)
                        ])

                # Compare the two segmentations
                self._assert_segmentations_equal(extracted_segmentation_by_name, extracted_segmentation_by_terminology)

                # Verify that the extracted segmentation contains the requested label values
                self._assert_label_voxel_counts(extracted_segmentation_by_terminology["voxels"], expected_voxel_counts)

    def test_segmentation_write(self):
        """Test that a segmentation read from file is unchanged after writing and re-reading it."""

        input_segmentation_filepath = slicerio.get_testdata_file('Segmentation.seg.nrrd')
        segmentation = slicerio.read_segmentation(input_segmentation_filepath)

        # Write and re-read the segmentation
        segmentation_stored = self._write_and_read_segmentation(segmentation)

        # Compare the two segmentations
        self._assert_segmentations_equal(segmentation, segmentation_stored)

    def test_segmentation_create(self):
        """Test that a segmentation created from scratch is unchanged after writing and re-reading it."""
        import numpy as np

        # Create segmentation with two rectangular prisms, with label values 1 and 3
        voxels = np.zeros([100, 120, 150], dtype=np.uint8)
        voxels[30:50, 20:60, 70:100] = 1
        voxels[70:90, 80:110, 60:110] = 3

        segmentation = {
            "voxels": voxels,
            "ijkToLPS": [
                [0.5, 0., 0., 10],
                [0., 0.5, 0., 30],
                [0., 0., 0.8, 15],
                [0., 0., 0., 1.]],
            "containedRepresentationNames": ["Binary labelmap", "Closed surface"],
            "segments": [
                {
                    "labelValue": 1,
                    "terminology": {
                        "contextName": "Segmentation category and type - 3D Slicer General Anatomy list",
                        "category": ["SCT", "123037004", "Anatomical Structure"],
                        "type": ["SCT", "10200004", "Liver"]}
                },
                {
                    "labelValue": 3,
                    "terminology": {
                        "contextName": "Segmentation category and type - 3D Slicer General Anatomy list",
                        "category": ["SCT", "123037004", "Anatomical Structure"],
                        "type": ["SCT", "39607008", "Lung"],
                        "typeModifier": ["SCT", "24028007", "Right"]}
                },
            ]
        }

        # Write and re-read the segmentation
        segmentation_stored = self._write_and_read_segmentation(segmentation)

        # Compare the two segmentations
        self._assert_segmentations_equal(segmentation, segmentation_stored)

    def _write_and_read_segmentation(self, segmentation):
        """Write segmentation to a temporary .seg.nrrd file and return the result of reading it back.

        The file is created in a temporary directory that is removed when this method
        returns, even if writing or reading raises an exception.
        """
        import os
        import tempfile

        with tempfile.TemporaryDirectory() as temporary_dir:
            output_segmentation_filepath = os.path.join(temporary_dir, 'Segmentation.seg.nrrd')
            slicerio.write_segmentation(output_segmentation_filepath, segmentation)
            return slicerio.read_segmentation(output_segmentation_filepath)

    def _assert_label_voxel_counts(self, voxels, expected_voxel_counts):
        """Assert that each label value occurs in voxels the expected number of times.

        voxels: numpy array of label values.
        expected_voxel_counts: dict mapping label value to the expected number of voxels.
        """
        import numpy as np

        for label_value, expected_count in expected_voxel_counts.items():
            self.assertEqual(
                np.count_nonzero(voxels == label_value), expected_count,
                f"Unexpected number of voxels for label value {label_value}")

    def _assert_segmentations_equal(self, segmentation1, segmentation2):
        """Compare segmentation1 to segmentation2.
        Ignores attributes that are present in segmentation2 but not in segmentation1.
        This is useful because when a segmentation is written to file then some extra attributes
        may be added to the segments (such as extent).
        Both segmentations must contain the same number of segments, in the same order.
        """
        import numpy as np

        for key, value1 in segmentation1.items():
            self.assertIn(key, segmentation2, f"Missing key {key}")
            value2 = segmentation2[key]
            if isinstance(value1, np.ndarray) or isinstance(value2, np.ndarray):
                # np.allclose broadcasts its arguments, so shapes must be compared explicitly
                self.assertEqual(np.shape(value1), np.shape(value2), f"Shape mismatch for key {key}")
                self.assertTrue(np.allclose(value1, value2), f"Failed for key {key}")
            elif key == "segments":
                # zip() stops at the shorter list, so a missing or extra segment
                # would go unnoticed without this check
                self.assertEqual(len(value1), len(value2), f"Number of segments differs for key {key}")
                for segment_index, (segment1, segment2) in enumerate(zip(value1, value2)):
                    for segment_attribute, attribute_value1 in segment1.items():
                        location = f"{key}[{segment_index}][{segment_attribute}]"
                        self.assertIn(segment_attribute, segment2, f"Missing key {location}")
                        self.assertEqual(attribute_value1, segment2[segment_attribute], f"Failed for key {location}")
            else:
                self.assertEqual(value1, value2, f"Failed for key {key}")

# Run the tests of this module with the unittest command-line runner
# when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
