# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Preview] Live API client."""

import asyncio
import base64
import contextlib
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
import warnings

import google.auth
import pydantic
from websockets import ConnectionClosed

from . import _api_module
from . import _common
from . import _transformers as t
from . import client
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from . import live_converters
from .models import _Content_to_mldev
from .models import _Content_to_vertex


try:
  from websockets.asyncio.client import ClientConnection  # type: ignore
  from websockets.asyncio.client import connect  # type: ignore
except ModuleNotFoundError:
  # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
  from websockets.client import ClientConnection  # type: ignore
  from websockets.client import connect  # type: ignore

logger = logging.getLogger('google_genai.live')

_FUNCTION_RESPONSE_REQUIRES_ID = (
    'FunctionResponse request must have an `id` field from the'
    ' response of a ToolCall.FunctionalCalls in Google AI.'
)


class AsyncSession:
  """[Preview] AsyncSession."""

  def __init__(
      self, api_client: client.BaseApiClient, websocket: ClientConnection
  ):
    self._api_client = api_client
    self._ws = websocket

  async def send(
      self,
      *,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ):
    """[Deprecated] Send input to the model.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q3 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The method will send the input request to the server.

    Args:
      input: The input request to the model.
      end_of_turn: Whether the input is the last message in a turn.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    warnings.warn(
        'The `session.send` method is deprecated and will be removed in a '
        'future version (not before Q3 2025).\n'
        'Please use one of the more specific methods: `send_client_content`, '
        '`send_realtime_input`, or `send_tool_response` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    client_message = self._parse_client_message(input, end_of_turn)
    await self._ws.send(json.dumps(client_message))

  async def send_client_content(
      self,
      *,
      turns: Optional[
          Union[
              types.Content,
              types.ContentDict,
              list[Union[types.Content, types.ContentDict]]
          ]
      ] = None,
      turn_complete: bool = True,
  ):
    """Send non-realtime, turn based content to the model.

    There are two ways to send messages to the live API:
    `send_client_content` and `send_realtime_input`.

    `send_client_content` messages are added to the model context **in order**.
    Having a conversation using `send_client_content` messages is roughly
    equivalent to using the `Chat.send_message_stream` method, except that the
    state of the `chat` history is stored on the API server.

    Because of `send_client_content`'s order guarantee, the model cannot
    respond as quickly to `send_client_content` messages as to
    `send_realtime_input` messages. This makes the biggest difference when
    sending objects that have significant preprocessing time (typically images).

    The `send_client_content` message sends a list of `Content` objects,
    which has more options than the `media:Blob` sent by `send_realtime_input`.

    The main use-cases for `send_client_content` over `send_realtime_input` are:

    - Prefilling a conversation context (including sending anything that can't
      be represented as a realtime message), before starting a realtime
      conversation.
    - Conducting a non-realtime conversation, similar to `client.chat`, using
      the live api.

    Caution: Interleaving `send_client_content` and `send_realtime_input`
      in the same conversation is not recommended and can lead to unexpected
      results.

    Args:
      turns: A `Content` object or list of `Content` objects (or equivalent
        dicts).
      turn_complete: if true (the default) the model will reply immediately. If
        false, the model will wait for you to send additional client_content,
        and will not return until you send `turn_complete=True`.

    Example:
    ```
    import google.genai
    from google.genai import types
    import os

    if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
      MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
    else:
      MODEL_NAME = 'gemini-2.0-flash-live-001';

    client = genai.Client()
    async with client.aio.live.connect(
        model=MODEL_NAME,
        config={"response_modalities": ["TEXT"]}
    ) as session:
      await session.send_client_content(
          turns=types.Content(
              role='user',
              parts=[types.Part(text="Hello world!")]))
      async for msg in session.receive():
        if msg.text:
          print(msg.text)
    ```
    """
    client_content = t.t_client_content(turns, turn_complete)

    if self._api_client.vertexai:
      client_content_dict = live_converters._LiveClientContent_to_vertex(
          api_client=self._api_client, from_object=client_content
      )
    else:
      client_content_dict = live_converters._LiveClientContent_to_mldev(
          api_client=self._api_client, from_object=client_content
      )

    await self._ws.send(json.dumps({'client_content': client_content_dict}))

  async def send_realtime_input(self, *, media: t.BlobUnion):
    """Send realtime media chunks to the model.

    Use `send_realtime_input` for realtime audio chunks and video
    frames(images).

    With `send_realtime_input` the api will respond to audio automatically
    based on voice activity detection (VAD).

    `send_realtime_input` is optimized for responsivness at the expense of
    deterministic ordering. Audio and video tokens are added to the
    context when they become available.

    Args:
      media: A `Blob`-like object, the realtime media to send.

    Example:
    ```
    from pathlib import Path

    from google import genai
    from google.genai import types

    import PIL.Image
    
    import os

    if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
      MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
    else:
      MODEL_NAME = 'gemini-2.0-flash-live-001';


    client = genai.Client()

    async with client.aio.live.connect(
        model=MODEL_NAME,
        config={"response_modalities": ["TEXT"]},
    ) as session:
      await session.send_realtime_input(
          media=PIL.Image.open('image.jpg'))

      audio_bytes = Path('audio.pcm').read_bytes()
      await session.send_realtime_input(
          media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))

      async for msg in session.receive():
        if msg.text is not None:
          print(f'{msg.text}')
    ```
    """
    realtime_input = t.t_realtime_input(media)
    realtime_input_dict = realtime_input.model_dump(
        exclude_none=True, mode='json'
    )
    await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))

  async def send_tool_response(
      self,
      *,
      function_responses: Union[
          types.FunctionResponseOrDict,
          Sequence[types.FunctionResponseOrDict],
      ],
  ):
    """Send a tool response to the session.

    Use `send_tool_response` to reply to `LiveServerToolCall` messages
    from the server.

    To set the available tools, use the `config.tools` argument
    when you connect to the session (`client.live.connect`).

    Args:
      function_responses: A `FunctionResponse`-like object or list of
        `FunctionResponse`-like objects.

    Example:
    ```
    from google import genai
    from google.genai import types

    import os

    if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
      MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
    else:
      MODEL_NAME = 'gemini-2.0-flash-live-001';

    client = genai.Client()

    tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
    config = {
        "tools": tools,
        "response_modalities": ['TEXT']
    }

    async with client.aio.live.connect(
        model='models/gemini-2.0-flash-live-001',
        config=config
    ) as session:
      prompt = "Turn on the lights please"
      await session.send_client_content(
          turns={"parts": [{'text': prompt}]}
      )

      async for chunk in session.receive():
          if chunk.server_content:
            if chunk.text is not None:
              print(chunk.text)
          elif chunk.tool_call:
            print(chunk.tool_call)
            print('_'*80)
            function_response=types.FunctionResponse(
                    name='turn_on_the_lights',
                    response={'result': 'ok'},
                    id=chunk.tool_call.function_calls[0].id,
                )
            print(function_response)
            await session.send_tool_response(
                function_responses=function_response
            )

            print('_'*80)
    """
    tool_response = t.t_tool_response(function_responses)
    if self._api_client.vertexai:
      tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
          api_client=self._api_client, from_object=tool_response
      )
    else:
      tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
          api_client=self._api_client, from_object=tool_response
      )
    await self._ws.send(json.dumps({'tool_response': tool_response_dict}))

  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
    """Receive model responses from the server.

    The method will yield the model responses from the server. The returned
    responses will represent a complete model turn. When the returned message
    is function call, user must call `send` with the function response to
    continue the turn.

    Yields:
      The model responses from the server.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    while result := await self._receive():
      if result.server_content and result.server_content.turn_complete:
        yield result
        break
      yield result

  async def start_stream(
      self, *, stream: AsyncIterator[bytes], mime_type: str
  ) -> AsyncIterator[types.LiveServerMessage]:
    """[Deprecated] Start a live session from a data stream.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q2 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The interaction terminates when the input stream is complete.
    This method will start two async tasks. One task will be used to send the
    input stream to the model and the other task will be used to receive the
    responses from the model.

    Args:
      stream: An iterator that yields the model response.
      mime_type: The MIME type of the data in the stream.

    Yields:
      The audio bytes received from the model and server response messages.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {'response_modalities': ['AUDIO']}
      async def audio_stream():
        stream = read_audio()
        for data in stream:
          yield data
      async with client.aio.live.connect(model='...', config=config) as session:
        for audio in session.start_stream(stream = audio_stream(),
        mime_type = 'audio/pcm'):
          play_audio_chunk(audio.data)
    """
    warnings.warn(
        'Setting `AsyncSession.start_stream` is deprecated, '
        'and will be removed in a future release (not before Q3 2025). '
        'Please use the `receive`, and `send_realtime_input`, methods instead.',
        DeprecationWarning,
        stacklevel=4,
    )
    stop_event = asyncio.Event()
    # Start the send loop. When stream is complete stop_event is set.
    asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
    recv_task = None
    while not stop_event.is_set():
      try:
        recv_task = asyncio.create_task(self._receive())
        await asyncio.wait(
            [
                recv_task,
                asyncio.create_task(stop_event.wait()),
            ],
            return_when=asyncio.FIRST_COMPLETED,
        )
        if recv_task.done():
          yield recv_task.result()
          # Give a chance for the send loop to process requests.
          await asyncio.sleep(10**-12)
      except ConnectionClosed:
        break
    if recv_task is not None and not recv_task.done():
      recv_task.cancel()
      # Wait for the task to finish (cancelled or not)
      try:
        await recv_task
      except asyncio.CancelledError:
        pass

  async def _receive(self) -> types.LiveServerMessage:
    parameter_model = types.LiveServerMessage()
    try:
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError:
        raise ValueError(f'Failed to parse response: {raw_response!r}')
    else:
      response = {}

    if self._api_client.vertexai:
      response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
    else:
      response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)

    return types.LiveServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

  async def _send_loop(
      self,
      data_stream: AsyncIterator[bytes],
      mime_type: str,
      stop_event: asyncio.Event,
  ):
    async for data in data_stream:
      model_input = types.LiveClientRealtimeInput(
        media_chunks=[types.Blob(data=data, mime_type=mime_type)]
      )
      await self.send(input=model_input)
      # Give a chance for the receive loop to process responses.
      await asyncio.sleep(10**-12)
    # Give a chance for the receiver to process the last response.
    stop_event.set()

  def _parse_client_message(
      self,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> types.LiveClientMessageDict:

    formatted_input: Any = input

    if not input:
      logging.info('No input provided. Assume it is the end of turn.')
      return {'client_content': {'turn_complete': True}}
    if isinstance(input, str):
      formatted_input = [input]
    elif isinstance(input, dict) and 'data' in input:
      try:
        blob_input = types.Blob(**input)
      except pydantic.ValidationError:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
      if (
          isinstance(blob_input, types.Blob)
          and isinstance(blob_input.data, bytes)
      ):
        formatted_input = [
            blob_input.model_dump(mode='json', exclude_none=True)
        ]
    elif isinstance(input, types.Blob):
      formatted_input = [input]
    elif isinstance(input, dict) and 'name' in input and 'response' in input:
      # ToolResponse.FunctionResponse
      if not (self._api_client.vertexai) and 'id' not in input:
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      formatted_input = [input]

    if isinstance(formatted_input, Sequence) and any(
        isinstance(c, dict) and 'name' in c and 'response' in c
        for c in formatted_input
    ):
      # ToolResponse.FunctionResponse
      function_responses_input = []
      for item in formatted_input:
        if isinstance(item, dict):
          try:
            function_response_input = types.FunctionResponse(**item)
          except pydantic.ValidationError:
            raise ValueError(
                f'Unsupported input type "{type(input)}" or input content'
                f' "{input}"'
            )
          if (
              function_response_input.id is None
              and not self._api_client.vertexai
          ):
            raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
          else:
            function_response_dict = function_response_input.model_dump(
                exclude_none=True, mode='json'
            )
            function_response_typeddict = types.FunctionResponseDict(
                name=function_response_dict.get('name'),
                response=function_response_dict.get('response'),
            )
            if function_response_dict.get('id'):
              function_response_typeddict['id'] = function_response_dict.get(
                  'id'
              )
            function_responses_input.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_responses_input
          )
      )
    elif isinstance(formatted_input, Sequence) and any(
        isinstance(c, str) for c in formatted_input
    ):
      to_object: dict[str, Any] = {}
      content_input_parts: list[types.PartUnion] = []
      for item in formatted_input:
        if isinstance(item, get_args(types.PartUnion)):
          content_input_parts.append(item)
      if self._api_client.vertexai:
        contents = [
            _Content_to_vertex(self._api_client, item, to_object)
            for item in t.t_contents(self._api_client, content_input_parts)
        ]
      else:
        contents = [
            _Content_to_mldev(self._api_client, item, to_object)
            for item in t.t_contents(self._api_client, content_input_parts)
        ]

      content_dict_list: list[types.ContentDict] = []
      for item in contents:
        try:
          content_input = types.Content(**item)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        content_dict_list.append(
            types.ContentDict(
                parts=content_input.model_dump(exclude_none=True, mode='json')[
                    'parts'
                ],
                role=content_input.role,
            )
        )

      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=content_dict_list, turn_complete=end_of_turn
          )
      )
    elif isinstance(formatted_input, Sequence):
      if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
        pass
      elif any(isinstance(b, types.Blob) for b in formatted_input):
        formatted_input = [
            b.model_dump(exclude_none=True, mode='json')
            for b in formatted_input
        ]
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )

      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=formatted_input
          )
      )

    elif isinstance(formatted_input, dict):
      if 'content' in formatted_input or 'turns' in formatted_input:
        # TODO(b/365983264) Add validation checks for content_update input_dict.
        if 'turns' in formatted_input:
          content_turns = formatted_input['turns']
        else:
          content_turns = formatted_input['content']
        client_message = types.LiveClientMessageDict(
            client_content=types.LiveClientContentDict(
                turns=content_turns,
                turn_complete=formatted_input.get('turn_complete'),
            )
        )
      elif 'media_chunks' in formatted_input:
        try:
          realtime_input = types.LiveClientRealtimeInput(**formatted_input)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        client_message = types.LiveClientMessageDict(
            realtime_input=types.LiveClientRealtimeInputDict(
                media_chunks=realtime_input.model_dump(
                    exclude_none=True, mode='json'
                )['media_chunks']
            )
        )
      elif 'function_responses' in formatted_input:
        try:
          tool_response_input = types.LiveClientToolResponse(**formatted_input)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        client_message = types.LiveClientMessageDict(
            tool_response=types.LiveClientToolResponseDict(
                function_responses=tool_response_input.model_dump(
                    exclude_none=True, mode='json'
                )['function_responses']
            )
        )
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
    elif isinstance(formatted_input, types.LiveClientRealtimeInput):
      realtime_input_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=realtime_input_dict.get('media_chunks')
          )
      )
      if (
          client_message['realtime_input'] is not None
          and client_message['realtime_input']['media_chunks'] is not None
          and isinstance(
              client_message['realtime_input']['media_chunks'][0]['data'], bytes
          )
      ):
        formatted_media_chunks: list[types.BlobDict] = []
        for item in client_message['realtime_input']['media_chunks']:
          if isinstance(item, dict):
            try:
              blob_input = types.Blob(**item)
            except pydantic.ValidationError:
              raise ValueError(
                  f'Unsupported input type "{type(input)}" or input content'
                  f' "{input}"'
              )
            if (
                isinstance(blob_input, types.Blob)
                and isinstance(blob_input.data, bytes)
                and blob_input.data is not None
            ):
              formatted_media_chunks.append(
                  types.BlobDict(
                      data=base64.b64decode(blob_input.data),
                      mime_type=blob_input.mime_type,
                  )
              )

        client_message['realtime_input'][
            'media_chunks'
        ] = formatted_media_chunks

    elif isinstance(formatted_input, types.LiveClientContent):
      client_content_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=client_content_dict.get('turns'),
              turn_complete=client_content_dict.get('turn_complete'),
          )
      )
    elif isinstance(formatted_input, types.LiveClientToolResponse):
      # ToolResponse.FunctionResponse
      if (
          not (self._api_client.vertexai)
          and formatted_input.function_responses is not None
          and not (formatted_input.function_responses[0].id)
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=formatted_input.model_dump(
                  exclude_none=True, mode='json'
              ).get('function_responses')
          )
      )
    elif isinstance(formatted_input, types.FunctionResponse):
      if not (self._api_client.vertexai) and not (formatted_input.id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      function_response_typeddict = types.FunctionResponseDict(
          name=function_response_dict.get('name'),
          response=function_response_dict.get('response'),
      )
      if function_response_dict.get('id'):
        function_response_typeddict['id'] = function_response_dict.get('id')
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=[function_response_typeddict]
          )
      )
    elif isinstance(formatted_input, Sequence) and isinstance(
        formatted_input[0], types.FunctionResponse
    ):
      if not (self._api_client.vertexai) and not (formatted_input[0].id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_list: list[types.FunctionResponseDict] = []
      for item in formatted_input:
        function_response_dict = item.model_dump(exclude_none=True, mode='json')
        function_response_typeddict = types.FunctionResponseDict(
            name=function_response_dict.get('name'),
            response=function_response_dict.get('response'),
        )
        if function_response_dict.get('id'):
          function_response_typeddict['id'] = function_response_dict.get('id')
        function_response_list.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_response_list
          )
      )

    else:
      raise ValueError(
          f'Unsupported input type "{type(input)}" or input content "{input}"'
      )

    return client_message

  async def close(self) -> None:
    # Close the websocket connection.
    await self._ws.close()



class AsyncLive(_api_module.BaseModule):
  """[Preview] AsyncLive."""


  @contextlib.asynccontextmanager
  async def connect(
      self,
      *,
      model: str,
      config: Optional[types.LiveConnectConfigOrDict] = None,
  ) -> AsyncIterator[AsyncSession]:
    """[Preview] Connect to the live server.

    Note: the live API is currently in preview.

    Usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {}
      async with client.aio.live.connect(model='...', config=config) as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    base_url = self._api_client._websocket_base_url()
    transformed_model = t.t_model(self._api_client, model)

    parameter_model = _t_live_connect_config(self._api_client, config)

    if self._api_client.api_key:
      api_key = self._api_client.api_key
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
      headers = self._api_client._http_options.headers

      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_mldev(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                model=transformed_model,
                config=parameter_model,
              ).model_dump(exclude_none=True)
          )
      )
      del request_dict['config']

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    else:
      # Get bearer token through Application Default Credentials.
      creds, _ = google.auth.default(
          scopes=['https://www.googleapis.com/auth/cloud-platform']
      )

      # creds.valid is False, and creds.token is None
      # Need to refresh credentials to populate those
      auth_req = google.auth.transport.requests.Request()
      creds.refresh(auth_req)
      bearer_token = creds.token
      headers = self._api_client._http_options.headers
      if headers is not None:
        headers.update({
            'Authorization': 'Bearer {}'.format(bearer_token),
        })
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
      location = self._api_client.location
      project = self._api_client.project
      if transformed_model.startswith('publishers/'):
        transformed_model = (
            f'projects/{project}/locations/{location}/' + transformed_model
        )
      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_vertex(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                model=transformed_model,
                config=parameter_model,
              ).model_dump(exclude_none=True)
          )
      )
      del request_dict['config']

      if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
        setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])

      request = json.dumps(request_dict)

    try:
      async with connect(uri, additional_headers=headers) as ws:
        await ws.send(request)
        logger.info(await ws.recv(decode=False))

        yield AsyncSession(api_client=self._api_client, websocket=ws)
    except TypeError:
      # Try with the older websockets API
      async with connect(uri, extra_headers=headers) as ws:
        await ws.send(request)
        logger.info(await ws.recv())

        yield AsyncSession(api_client=self._api_client, websocket=ws)


def _t_live_connect_config(
    api_client: BaseApiClient,
    config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
  # Ensure the config is a LiveConnectConfig.
  if config is None:
    parameter_model = types.LiveConnectConfig()
  elif isinstance(config, dict):
    if getv(config, ['system_instruction']) is not None:
      converted_system_instruction = t.t_content(
          api_client, getv(config, ['system_instruction'])
      )
    else:
      converted_system_instruction = None
    parameter_model = types.LiveConnectConfig(**config)
    parameter_model.system_instruction = converted_system_instruction
  else:
    if config.system_instruction is None:
      system_instruction = None
    else:
      system_instruction = t.t_content(
          api_client, getv(config, ['system_instruction'])
      )
    parameter_model = config
    parameter_model.system_instruction = system_instruction

  if parameter_model.generation_config is not None:
    warnings.warn(
        'Setting `LiveConnectConfig.generation_config` is deprecated, '
        'please set the fields on `LiveConnectConfig` directly. This will '
        'become an error in a future version (not before Q3 2025)',
        DeprecationWarning,
        stacklevel=4,
    )

  return parameter_model
