# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
from haystack import Pipeline, component
from haystack.components.embedders import OpenAITextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils.auth import Secret

from burr.core import State, action
from burr.core.application import ApplicationBuilder
from burr.core.graph import GraphBuilder
from burr.integrations.haystack import HaystackAction, haystack_pipeline_to_burr_graph


# Minimal Haystack component used throughout these tests: ``run`` echoes
# ``required_input`` as ``output_1`` and ``optional_input`` as ``output_2``.
# NOTE: ``test_get_component_source`` compares this class's source text against
# a hard-coded string, so the class body below must stay byte-identical to that
# string. Comments are kept above the decorator so they stay out of that source
# (assumes ``get_source`` starts at the decorator line like ``inspect.getsource``
# -- confirm if that test starts failing after editing this comment).
@component
class MockComponent:
    def __init__(self, required_init: str, optional_init: str = "default"):
        self.required_init = required_init
        self.optional_init = optional_init

    @component.output_types(output_1=str, output_2=str)
    def run(self, required_input: str, optional_input: str = "default") -> dict:
        return {
            "output_1": required_input,
            "output_2": optional_input,
        }


@component
class MockComponentWithWarmup:
    """Echo component that refuses to run until ``warm_up()`` has been called.

    Lets the tests observe whether ``HaystackAction`` warms up the component
    it wraps before running it.
    """

    def __init__(self, required_init: str, optional_init: str = "default"):
        # Start "cold"; ``warm_up`` flips the flag that ``run`` checks.
        self.is_warm = False
        self.required_init = required_init
        self.optional_init = optional_init

    def warm_up(self):
        """Mark the component as ready to run."""
        self.is_warm = True

    @component.output_types(output_1=str, output_2=str)
    def run(self, required_input: str, optional_input: str = "default") -> dict:
        """Echo the inputs back as ``output_1`` / ``output_2``.

        Raises:
            RuntimeError: if ``warm_up()`` has not been called yet.
        """
        if self.is_warm is False:
            raise RuntimeError("You must call ``warm_up()`` before running.")

        outputs = dict(output_1=required_input, output_2=optional_input)
        return outputs


@action(reads=["query_embedding"], writes=["documents"])
def retrieve_documents(state: State) -> State:
    """Hand-written Burr action: retrieve documents for ``state["query_embedding"]``.

    A brand-new ``InMemoryDocumentStore`` is created on every call and nothing
    is written to it before querying. The retriever's ``documents`` output is
    stored in state under ``documents``.

    NOTE(review): nothing in this file calls this action, and
    ``test_pipeline_converter`` defines a local of the same name. It may be a
    leftover example; confirm before relying on it.
    """
    embedding = state["query_embedding"]

    retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
    retrieval = retriever.run(query_embedding=embedding)

    return state.update(documents=retrieval["documents"])


def test_input_socket_mapping():
    """A dict ``reads`` maps {input_socket_name: state_field}.

    The action's ``reads`` must list the state fields (the dict values).
    """
    socket_to_field = {"required_input": "foo"}
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=socket_to_field,
        writes=[],
    )

    expected_fields = ["foo"]
    assert list(set(socket_to_field.values())) == expected_fields
    assert mock_action.reads == expected_fields


def test_input_socket_sequence():
    """A list ``reads`` means each input socket reads the state field of the same name."""
    socket_names = ["required_input"]
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=socket_names,
        writes=[],
    )

    expected_fields = ["required_input"]
    assert list(socket_names) == expected_fields
    assert mock_action.reads == expected_fields


def test_output_socket_mapping():
    """A dict ``writes`` maps {state_field: output_socket_name}.

    The action's ``writes`` must list the state fields (the dict keys).
    """
    field_to_socket = {"bar": "output_1"}
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=field_to_socket,
    )

    expected_fields = ["bar"]
    assert list(field_to_socket.keys()) == expected_fields
    assert mock_action.writes == expected_fields


def test_output_socket_sequence():
    """A list ``writes`` means each output socket is stored under its own name."""
    socket_names = ["output_1"]
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=socket_names,
    )

    assert socket_names == ["output_1"]
    assert mock_action.writes == ["output_1"]


def test_get_component_source():
    """``get_source`` returns the source text of the wrapped component's class.

    The expected string must match ``MockComponent``'s class body exactly.
    """
    expected_source = """\
@component
class MockComponent:
    def __init__(self, required_init: str, optional_init: str = "default"):
        self.required_init = required_init
        self.optional_init = optional_init

    @component.output_types(output_1=str, output_2=str)
    def run(self, required_input: str, optional_input: str = "default") -> dict:
        return {
            "output_1": required_input,
            "output_2": optional_input,
        }
"""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=[],
    )

    actual_source = mock_action.get_source()
    assert actual_source == expected_source


def test_run_with_external_inputs():
    """Keyword inputs passed to ``run`` reach the component's input sockets."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=[],
    )
    empty_state = State(initial_values={})

    outputs = mock_action.run(state=empty_state, required_input="as_input")

    # optional_input was not supplied, so the component's default is echoed.
    expected = {"output_1": "as_input", "output_2": "default"}
    assert outputs == expected


def test_run_with_state_inputs():
    """A ``reads`` mapping feeds an input socket from the named state field."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads={"required_input": "foo"},
        writes=[],
    )
    initial_state = State(initial_values={"foo": "bar"})

    outputs = mock_action.run(state=initial_state)

    # state["foo"] ("bar") was routed into required_input.
    expected = {"output_1": "bar", "output_2": "default"}
    assert outputs == expected


def test_run_with_bound_params():
    """``bound_params`` supplies socket values fixed when the action is built."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=[],
        bound_params={"required_input": "baz"},
    )
    empty_state = State(initial_values={})

    outputs = mock_action.run(state=empty_state)

    expected = {"output_1": "baz", "output_2": "default"}
    assert outputs == expected


def test_run_mixed_params():
    """State-sourced and bound inputs can be combined in a single run."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads={"required_input": "foo"},
        writes=[],
        bound_params={"optional_input": "baz"},
    )
    initial_state = State(initial_values={"foo": "bar"})

    outputs = mock_action.run(state=initial_state)

    # required_input comes from state, optional_input from bound_params.
    expected = {"output_1": "bar", "output_2": "baz"}
    assert outputs == expected


def test_run_with_sequence():
    """A list ``reads`` pulls each socket from the state field of the same name."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=["required_input"],
        writes=[],
    )
    initial_state = State(initial_values={"required_input": "bar"})

    outputs = mock_action.run(state=initial_state)

    expected = {"output_1": "bar", "output_2": "default"}
    assert outputs == expected


def test_update_with_writes_mapping():
    """A ``writes`` mapping stores the chosen output socket under the state field name."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes={"foo": "output_1"},
    )
    component_outputs = {"output_1": 1, "output_2": 2}
    empty_state = State(initial_values={})

    updated = mock_action.update(result=component_outputs, state=empty_state)

    # output_1 (== 1) lands in state field "foo".
    assert updated["foo"] == 1


def test_update_with_writes_sequence():
    """A list ``writes`` stores each listed output socket under its own name."""
    mock_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=["output_1"],
    )
    component_outputs = {"output_1": 1, "output_2": 2}
    empty_state = State(initial_values={})

    updated = mock_action.update(result=component_outputs, state=empty_state)

    assert updated["output_1"] == 1


def test_component_is_warmed_up():
    """With ``do_warm_up=True`` the wrapped component runs without raising."""
    warm_action = HaystackAction(
        component=MockComponentWithWarmup(required_init="init"),
        name="mock",
        reads=[],
        writes=[],
        do_warm_up=True,
    )
    empty_state = State(initial_values={})

    outputs = warm_action.run(state=empty_state, required_input="as_input")

    expected = {"output_1": "as_input", "output_2": "default"}
    assert outputs == expected


def test_component_is_not_warmed_up():
    """With ``do_warm_up=False`` the component's own warm-up guard must fire.

    ``match`` pins the specific error raised by ``MockComponentWithWarmup.run``.
    Without it, any unrelated RuntimeError (e.g. from input handling) would
    also make this test pass.
    """
    state = State(initial_values={})
    haction = HaystackAction(
        component=MockComponentWithWarmup(required_init="init"),
        name="mock",
        reads=[],
        writes=[],
        do_warm_up=False,
    )
    with pytest.raises(RuntimeError, match=r"warm_up\(\)"):
        haction.run(state=state, required_input="as_input")


def test_pipeline_converter():
    """``haystack_pipeline_to_burr_graph`` reproduces a hand-built Burr graph.

    Builds a two-component Haystack pipeline (text embedder -> retriever) and
    the equivalent Burr graph by hand with ``HaystackAction``. It then checks
    that every hand-built action and transition also appears in the converted
    graph, matching actions by name.
    """
    # create haystack Pipeline
    retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
    # Placeholder key: this test never calls the embedder itself
    # (assumes the conversion makes no API calls either).
    text_embedder = OpenAITextEmbedder(
        model="text-embedding-3-small", api_key=Secret.from_token("mock-key")
    )

    basic_rag_pipeline = Pipeline()
    basic_rag_pipeline.add_component("text_embedder", text_embedder)
    basic_rag_pipeline.add_component("retriever", retriever)
    basic_rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

    # create Burr application by hand (the expected shape of the conversion)
    embed_text = HaystackAction(
        component=text_embedder,
        name="text_embedder",
        reads=[],
        writes={"query_embedding": "embedding"},
    )

    # Named ``retrieve_action`` so it does not shadow the module-level
    # ``retrieve_documents`` action.
    retrieve_action = HaystackAction(
        component=retriever,
        name="retriever",
        reads=["query_embedding"],
        writes=["documents"],
    )

    burr_graph = (
        GraphBuilder()
        .with_actions(embed_text, retrieve_action)
        .with_transitions(("text_embedder", "retriever"))
        .build()
    )

    # convert the Haystack Pipeline to a Burr graph
    haystack_graph = haystack_pipeline_to_burr_graph(basic_rag_pipeline)

    # ``converted`` (not ``action``) avoids shadowing the imported ``@action``.
    converted_action_names = {converted.name for converted in haystack_graph.actions}
    for graph_action in burr_graph.actions:
        assert graph_action.name in converted_action_names, (
            f"action {graph_action.name!r} missing from converted graph"
        )

    for burr_t in burr_graph.transitions:
        assert any(
            burr_t.from_.name == haystack_t.from_.name and burr_t.to.name == haystack_t.to.name
            for haystack_t in haystack_graph.transitions
        ), f"transition {burr_t.from_.name!r} -> {burr_t.to.name!r} missing from converted graph"


def test_run_application():
    """A ``HaystackAction`` runs inside a Burr application and writes to state."""
    echo_action = HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=[],
        writes=["output_1"],
    )
    builder = ApplicationBuilder().with_actions(echo_action).with_transitions()
    app = builder.with_entrypoint("mock").build()

    # required_input is supplied as a runtime input rather than from state.
    _, _, final_state = app.run(halt_after=["mock"], inputs={"required_input": "runtime"})

    assert final_state["output_1"] == "runtime"


def test_run_application_is_warm_up():
    """Inside an application, a warm-up-requiring component still runs.

    ``do_warm_up`` is left unset here, so this relies on ``HaystackAction``
    warming the component up by default. Otherwise
    ``MockComponentWithWarmup.run`` would raise.
    """
    warm_action = HaystackAction(
        component=MockComponentWithWarmup(required_init="init"),
        name="mock",
        reads=[],
        writes=["output_1"],
    )
    builder = ApplicationBuilder().with_actions(warm_action).with_transitions()
    app = builder.with_entrypoint("mock").build()

    _, _, final_state = app.run(halt_after=["mock"], inputs={"required_input": "runtime"})

    assert final_state["output_1"] == "runtime"


def test_run_application_is_not_warmed_up():
    """Inside an application, ``do_warm_up=False`` surfaces the warm-up error.

    ``match`` pins the error raised by ``MockComponentWithWarmup.run`` so that
    an unrelated RuntimeError from the application cannot satisfy the test.
    """
    app = (
        ApplicationBuilder()
        .with_actions(
            HaystackAction(
                component=MockComponentWithWarmup(required_init="init"),
                name="mock",
                reads=[],
                writes=["output_1"],
                do_warm_up=False,
            )
        )
        .with_transitions()
        .with_entrypoint("mock")
        .build()
    )
    with pytest.raises(RuntimeError, match=r"warm_up\(\)"):
        app.run(halt_after=["mock"], inputs={"required_input": "runtime"})
