# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Unit tests for scheduler tags backend functions."""

import ast
from unittest import TestCase

import pyparsing

from debusine.utils.tag_utils import (
    DerivationRuleParser,
    DerivationRules,
    ProvenanceRestrictions,
)


class ProvenanceRestrictionsTests(TestCase):
    """Exercise the :py:class:`ProvenanceRestrictions` container."""

    def test_name(self) -> None:
        # The name given at construction is exposed unchanged.
        restrictions = ProvenanceRestrictions("foo")
        self.assertEqual(restrictions.name, "foo")

    def test_add_exact(self) -> None:
        # Exact rules accept any iterable of provenances, are stored as
        # sets, and do not leak into the prefix rules.
        restrictions = ProvenanceRestrictions("foo")
        restrictions.add_exact("foo", ["bar"])
        restrictions.add_exact("wibble", {"wobble"})
        expected_exact = {"foo": {"bar"}, "wibble": {"wobble"}}
        self.assertEqual(restrictions.exact, expected_exact)
        self.assertEqual(restrictions.prefixes, [])

    def test_add_exact_twice(self) -> None:
        # Registering the same exact tag twice is rejected; the message is
        # prefixed with the restrictions' name.
        restrictions = ProvenanceRestrictions("prov")
        restrictions.add_exact("foo", ["bar"])
        expected_message = (
            r"prov: exact match for tag foo specified multiple times"
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            restrictions.add_exact("foo", ["bar"])

    def test_add_prefix(self) -> None:
        # Prefix rules are kept as an ordered list of (prefix, set) pairs
        # and do not leak into the exact rules.
        restrictions = ProvenanceRestrictions("foo")
        restrictions.add_prefix("foo", ("bar",))
        restrictions.add_prefix("wibble", ["wobble"])
        self.assertEqual(restrictions.exact, {})
        expected_prefixes = [("foo", {"bar"}), ("wibble", {"wobble"})]
        self.assertEqual(restrictions.prefixes, expected_prefixes)

    def test_filter_set(self) -> None:
        restrictions = ProvenanceRestrictions("test")
        restrictions.add_prefix("group", ["system"])
        restrictions.add_exact("official:foo", ["system", "scope:foo"])
        restrictions.add_exact("official:bar", ["system", "scope:bar"])
        restrictions.add_prefix("official:bar", ["scope:bar"])

        # Each case is (provenance, input tags, expected output).  An
        # expected output of None means every input tag is kept as-is.
        # Each case builds its own set literal so no set is shared.
        cases: list[tuple[str, set[str], set[str] | None]] = [
            ("foo", set(), None),
            ("foo", {"foo", "bar"}, None),
            ("system", {"group:foo", "group:bar", "baz"}, None),
            ("other", {"group:foo", "group:bar", "baz"}, {"baz"}),
            (
                "system",
                {"group:foo", "official:foo", "official:bar"},
                {"group:foo", "official:foo"},
            ),
            (
                "scope:foo",
                {"group:foo", "official:foo", "official:bar"},
                {"official:foo"},
            ),
            (
                "scope:bar",
                {"group:foo", "official:foo", "official:bar"},
                {"official:bar"},
            ),
        ]
        for provenance, tags, filtered in cases:
            with self.subTest(provenance=provenance, tags=repr(tags)):
                expected = tags if filtered is None else filtered
                self.assertEqual(
                    restrictions.filter_set(provenance, tags), expected
                )


class DerivationRuleParserTests(TestCase):
    """Tests for :py:class:`DerivationRuleParser`."""

    def test_tokens_invalid_string(self) -> None:
        """Characters outside the tag grammar are rejected with a position.

        Each case pairs an input with the character offset that the
        ``ParseException`` message is expected to report.
        """
        for value, pos in (
            ("/", 0),
            (" /", 1),
            ("tag@tag", 3),
            ("[tag]", 0),
            ("tåg", 1),
        ):
            with (
                self.subTest(value=value),
                self.assertRaisesRegex(
                    pyparsing.ParseException,
                    rf"Expected .*, found '.*'  \(at char {pos}\)",
                ),
            ):
                DerivationRuleParser(value).as_expr()

    def test_as_expr(self) -> None:
        """Rules are translated into ``'tag' in tags`` membership tests.

        Expected strings are the exact :py:func:`ast.unparse` rendering of
        the generated expression. They therefore pin both operator
        precedence (``and`` binds tighter than ``or``) and the
        parenthesization that ``ast.unparse`` emits.
        """
        for value, parsed in (
            ("tag", "'tag' in tags"),
            ("not tag", "not 'tag' in tags"),
            ("(tag)", "'tag' in tags"),
            # See https://github.com/pyparsing/pyparsing/issues/204
            ("(((((((((tag)))))))))", "'tag' in tags"),
            ("   (  tag  )   ", "'tag' in tags"),
            ("tag1 and tag2", "'tag1' in tags and 'tag2' in tags"),
            (
                "tag1 and tag2 and tag3",
                "'tag1' in tags and 'tag2' in tags and ('tag3' in tags)",
            ),
            ("tag1 or tag2", "'tag1' in tags or 'tag2' in tags"),
            (
                "tag1 or tag2 or tag3",
                "'tag1' in tags or 'tag2' in tags or 'tag3' in tags",
            ),
            (
                "tag1 and tag2 or tag3",
                "'tag1' in tags and 'tag2' in tags or 'tag3' in tags",
            ),
            (
                "tag and (tag1 or tag2)",
                "'tag' in tags and ('tag1' in tags or 'tag2' in tags)",
            ),
            (
                "(tag1 and not tag2) or tag3",
                "'tag1' in tags and (not 'tag2' in tags) or 'tag3' in tags",
            ),
            (
                "tag1 or tag2 and tag3",
                "'tag1' in tags or ('tag2' in tags and 'tag3' in tags)",
            ),
            ("not tag1 and tag2", "not 'tag1' in tags and 'tag2' in tags"),
            (
                "not not not (((tag1))) or tag3",
                "not not not 'tag1' in tags or 'tag3' in tags",
            ),
            (
                "tag1 and not tag2 and tag3",
                "'tag1' in tags and (not 'tag2' in tags) and ('tag3' in tags)",
            ),
            (
                "tag1 and (tag2 or tag3)",
                "'tag1' in tags and ('tag2' in tags or 'tag3' in tags)",
            ),
            # Operator keywords embedded in a longer word are part of the
            # tag name, not operators.
            ("fooandbar", "'fooandbar' in tags"),
            ("fooorbar", "'fooorbar' in tags"),
        ):
            with self.subTest(value=value):
                parser = DerivationRuleParser(value)
                expr = parser.as_expr()
                self.assertEqual(ast.unparse(expr), parsed)

    def test_as_expr_invalid_syntax(self) -> None:
        """Malformed rules fail to parse.

        This includes attempts to smuggle in arbitrary Python, such as
        attribute access or arithmetic.
        """
        for expr, message in (
            ("foo(", r"Expected end of text, found '\('"),
            ("foo bar", "Expected end of text, found 'bar'"),
            ("foo and", "Expected end of text, found 'and'"),
            ("or bar", r"Expected 'or' (term|operations), found 'or'"),
            ("()", r"Expected 'or' (term|operations), found '\)'"),
            ("(or)", "Expected 'or' (term|operations), found 'or'"),
            ("(and)", "Expected 'or' (term|operations), found 'and'"),
            ("foo.__globals__", r"Expected end of text, found '\.'"),
            ("foo + bar", r"Expected end of text, found '\+'"),
            ("' or \"", "Expected 'or' (term|operations), found \"'\""),
        ):
            with (
                self.subTest(expr=expr),
                self.assertRaisesRegex(pyparsing.ParseException, message),
            ):
                parser = DerivationRuleParser(expr)
                parser.as_expr()

    def test_as_function(self) -> None:
        """The callable built from a rule evaluates it against a tag set."""
        for value, tags, result in (
            ("foo", {"foo", "bar"}, True),
            ("foo", {"bar"}, False),
            ("foo and bar", {"foo", "bar"}, True),
            ("foo and bar", {"foo"}, False),
            ("foo or bar", {"foo", "bar"}, True),
            ("foo or bar", {"foo"}, True),
            ("not foo", {"bar"}, True),
            ("not foo", {"foo"}, False),
        ):
            # subTest keeps checking the remaining cases after a failure
            # and reports which rule/tag-set combination failed, matching
            # the other table-driven tests in this class.
            with self.subTest(value=value, tags=repr(tags)):
                parser = DerivationRuleParser(value)
                f = parser.as_function("test")
                self.assertEqual(f(tags), result)


class DerivationRulesTests(TestCase):
    """Exercise the :py:class:`DerivationRules` rule set."""

    def test_provenance(self) -> None:
        # The provenance given at construction is exposed unchanged.
        rules = DerivationRules(provenance="foo")
        self.assertEqual(rules.provenance, "foo")

    def test_compute(self) -> None:
        rules = DerivationRules(provenance="test")
        rules.add_rule("dd or dm", {"debian"})
        rules.add_rule("(dd or dm) and dsa", {"embargoed"})
        rules.add_rule("debian", {"official"})

        # Only tags derived from the input set are returned. Input tags
        # themselves are not echoed back, and a tag derived by one rule
        # (e.g. "debian") does not trigger another rule ("official").
        cases = [
            (set(), set()),
            ({"foo"}, set()),
            ({"dd"}, {"debian"}),
            ({"dm"}, {"debian"}),
            ({"dd", "dsa"}, {"debian", "embargoed"}),
            ({"dm", "dsa"}, {"debian", "embargoed"}),
            ({"foo", "dsa"}, set()),
            ({"foo", "debian"}, {"official"}),
        ]
        for input_tags, derived in cases:
            with self.subTest(tags=repr(input_tags)):
                self.assertEqual(rules.compute(input_tags), derived)
