#!/usr/bin/env python

import __init__

import os
import tempfile
import unittest

import rsvg
import cairo

import env, gui

class IconTest(unittest.TestCase):
    """Check gui.icon_get_surface against reference renderings.

    Each expected surface comes from _get_surface, which renders the same
    SVG template straight through rsvg/cairo; the pixel data of the surface
    returned by gui must match the reference byte for byte.
    """

    def setUp(self):
        # Icon used by most tests: black stroke and black fill.
        self.svg_file = _get_file('stroke_color "#000000"',
                                  'fill_color "#000000"')

    def tearDown(self):
        # Reset the icon cache even if deleting the file raises, so cached
        # surfaces never leak into the next test.
        try:
            os.unlink(self.svg_file)
        finally:
            gui.icon_cache_reset()

    def _assertSurfaceEqual(self, surface, sample):
        """Assert that two cairo image surfaces hold identical pixel data."""
        self.assertEqual(list(surface.get_data()), list(sample.get_data()))

    def testDefault(self):
        # With no size given, the result is expected to match the template
        # rendered at its own 50x50 size.
        attr = gui.IconAttr(
                file_name=self.svg_file)
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#000000"', 'fill_color "#000000"',
                              50, 50)
        self._assertSurfaceEqual(surface, sample)

    def testSize(self):
        # Only the width given: the height is expected to stay at 50.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10)
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#000000"', 'fill_color "#000000"',
                              10, 50)
        self._assertSurfaceEqual(surface, sample)

        # Both dimensions given.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10,
                height=10)
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#000000"', 'fill_color "#000000"',
                              10, 10)
        self._assertSurfaceEqual(surface, sample)

    def testColor(self):
        # NOTE(review): the expected hex strings suggest env.Color's last
        # three arguments are the R, G, B bytes - confirm against env.Color.

        # Stroke override only.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10,
                height=10,
                stroke_color=env.Color(1, 1, 2, 3))
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#010203"', 'fill_color "#000000"',
                              10, 10)
        self._assertSurfaceEqual(surface, sample)

        # Fill override only.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10,
                height=10,
                fill_color=env.Color(1, 4, 5, 6))
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#000000"', 'fill_color "#040506"',
                              10, 10)
        self._assertSurfaceEqual(surface, sample)

        # Both overrides.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10,
                height=10,
                stroke_color=env.Color(1, 1, 2, 3),
                fill_color=env.Color(1, 4, 5, 6))
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('stroke_color "#010203"', 'fill_color "#040506"',
                              10, 10)
        self._assertSurfaceEqual(surface, sample)

        # Same request again, but the reference SVG declares the two entities
        # in the opposite order; the rendering must not depend on that order.
        attr = gui.IconAttr(
                file_name=self.svg_file,
                width=10,
                height=10,
                stroke_color=env.Color(1, 1, 2, 3),
                fill_color=env.Color(1, 4, 5, 6))
        surface = gui.icon_get_surface(attr, False)
        sample = _get_surface('fill_color "#040506"', 'stroke_color "#010203"',
                              10, 10)
        self._assertSurfaceEqual(surface, sample)

    def testCache(self):
        """Cached surfaces must outlive their source files; uncached must not.

        NOTE(review): the second argument of gui.icon_get_surface appears to
        request storing the result in the cache - confirm in gui.
        """
        file1 = _get_file('stroke_color "#100000"', 'fill_color "#100000"')
        file2 = _get_file('stroke_color "#200000"', 'fill_color "#200000"')
        file3 = _get_file('stroke_color "#300000"', 'fill_color "#300000"')
        try:
            attrs = [{'file_name': file1},

                     {'file_name': file1,
                      'width': 10},

                     {'file_name': file1,
                      'fill_color': env.Color(1, 4, 5, 6)},

                     {'file_name': file1,
                      'width': 20,
                      'height': 20,
                      'stroke_color': env.Color(1, 1, 2, 3),
                      'fill_color': env.Color(1, 4, 5, 6)},

                     {'file_name': file2,
                      'width': 30,
                      'height': 40,
                      'stroke_color': env.Color(1, 1, 2, 3),
                      'fill_color': env.Color(1, 4, 5, 6)}]

            # Request every variant with True while the files still exist.
            for i in attrs:
                self.assertFalse(
                        gui.icon_get_surface(gui.IconAttr(**i), True) is None)

            # file3 is only ever requested with False.
            attr = gui.IconAttr(file3)
            uncached_surface = gui.icon_get_surface(attr, False)
            # Precondition: it must load while its file exists; otherwise the
            # "is None" check after deletion would pass for the wrong reason.
            self.assertFalse(uncached_surface is None)

            os.unlink(file1)
            os.unlink(file2)
            os.unlink(file3)

            # With its file gone, the never-cached icon cannot be produced.
            self.assertTrue(gui.icon_get_surface(attr, False) is None)

            # Every variant requested with True is still served, either flag.
            for i in attrs:
                attr = gui.IconAttr(**i)
                self.assertFalse(gui.icon_get_surface(attr, False) is None)
                self.assertFalse(gui.icon_get_surface(attr, True) is None)
        finally:
            # Remove any temp file left behind when an assertion fails before
            # the deliberate deletion above.
            for name in (file1, file2, file3):
                if os.path.exists(name):
                    os.unlink(name)


def _get_file(stroke_color, fill_color):
    """Write the SVG template, with the given entity bodies, to a temp file.

    stroke_color and fill_color are full entity declarations such as
    'stroke_color "#000000"'; they are substituted verbatim for the
    @STROKE@ and @FILL@ placeholders of the module-level ``template``.

    Returns the path of the newly created file; the caller must remove it.
    """
    svg = template.replace('@STROKE@', stroke_color) \
                  .replace('@FILL@', fill_color)
    # mkstemp creates and opens the file atomically. mktemp only picks a
    # name, leaving a window in which another process can create it first.
    fd, file_name = tempfile.mkstemp()
    svg_file = os.fdopen(fd, 'w')
    try:
        svg_file.write(svg)
    finally:
        svg_file.close()
    return file_name


def _get_surface(stroke_color, fill_color, width, height):
    """Render the SVG template directly with rsvg into a new ARGB32 surface.

    This is the reference rendering that the tests compare gui's output
    against. The template's 50x50 canvas is scaled to width x height. The
    temporary SVG file is always removed, even if rendering raises.
    """
    svg_path = _get_file(stroke_color, fill_color)
    try:
        result = cairo.ImageSurface(cairo.FORMAT_ARGB32, width, height)
        ctx = cairo.Context(result)
        x_scale = width / 50.
        y_scale = height / 50.
        ctx.scale(x_scale, y_scale)
        rsvg.Handle(svg_path).render_cairo(ctx)
    finally:
        os.unlink(svg_path)
    return result


# SVG source shared by the icon under test and the reference renderings.
# _get_file substitutes full entity declarations (e.g. 'stroke_color "#000000"')
# for the @STROKE@ and @FILL@ placeholders. Those declarations define the
# &stroke_color; and &fill_color; entities that the rect's stroke/fill
# attributes reference. The canvas is 50x50, matching the 50. divisor used
# when scaling reference renderings.
template = """<?xml version="1.0" ?>
<!DOCTYPE svg  PUBLIC '-//W3C//DTD SVG 1.1//EN'
    'http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd' [
    <!ENTITY @STROKE@>
    <!ENTITY @FILL@>
]>
<svg height="50px" version="1.1" width="50px" x="0px" y="0px">
    <g display="block" id="cell-format">
        <rect display="inline" fill="&fill_color;" height="30"
        stroke="&stroke_color;" stroke-width="2.25" width="30" x="10" y="10"/>
    </g>
</svg>
"""

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
