from typing import Optional
from urllib import parse

from .base_client import DatabaseClient
from .database_credentials import DatabaseCredentials

import sqlalchemy
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.dialects.mssql import base as mssql_base


class MSSQLGeography(sqlalchemy.types.UserDefinedType):
    """Minimal SQLAlchemy column type for SQL Server ``GEOGRAPHY`` columns.

    Only the DDL type name is supplied; no bind or result processing is
    defined here, so values pass through unchanged.
    """

    # The type carries no per-instance state, so it is safe for SQLAlchemy's
    # compiled-statement cache. Leaving ``cache_ok`` unset makes SQLAlchemy
    # emit a warning and skip caching for every statement using this type.
    cache_ok = True

    def get_col_spec(self, **kw):
        """Return the type name rendered in DDL for this column type."""
        return "GEOGRAPHY"


# Register MSSQLGeography in the MSSQL dialect's reflection type map so that
# reflected "geography" columns (e.g. during automap reflection) resolve to it.
# NOTE(review): without this entry reflection presumably falls back to a generic
# type with a warning — confirm for the installed SQLAlchemy version.
mssql_base.ischema_names.update({"geography": MSSQLGeography})


class SqlalchemyClient(DatabaseClient):
    """SQL Server client backed by a SQLAlchemy engine (``mssql+pyodbc`` dialect).

    On construction the client builds its connection URL, creates an engine and
    reflects the target database with ``automap_base`` so that mapped classes for
    existing tables are available through :attr:`tables`.
    """

    def __init__(self, credentials: DatabaseCredentials, database: str, driver: Optional[str] = None, **kwargs):
        """Create the engine and reflect the database schema.

        Args:
            credentials: Supplies the username, password and host used to connect.
            database: Name of the database to connect to.
            driver: ODBC driver name; when omitted, ``self._get_best_driver()`` chooses one.
            **kwargs: Passed to the base class; ``self.kwargs`` (presumably stored by
                the base class — verify) is forwarded to ``sqlalchemy.create_engine``.
        """
        super().__init__(credentials, database, **kwargs)
        self.driver = driver or self._get_best_driver()
        self._connection_string = self._build_connection_string()
        self._engine = sqlalchemy.create_engine(self._connection_string, **self.kwargs)

        # Reflection talks to the database, so construction requires a reachable
        # server. NOTE(review): per the automap docs, tables without a primary key
        # are not mapped — confirm this is acceptable for callers of `tables`.
        self._base = automap_base()
        self._base.prepare(autoload_with=self._engine)
        self._tables = self._base.classes

    @property
    def engine(self) -> sqlalchemy.engine.Engine:
        """The SQLAlchemy engine created for this client."""
        return self._engine

    @property
    def tables(self) -> "sqlalchemy.util.Properties":
        """Automap-generated mapped classes, keyed by table name.

        Entries are ORM classes (not strings), accessible as attributes,
        e.g. ``client.tables.my_table``.
        """
        return self._tables

    def _build_connection_string(self) -> str:
        """Build the SQLAlchemy ``mssql+pyodbc`` connection URL for this client.

        Username and password are percent-encoded with no "safe" characters, so
        ``/``, ``:``, ``@`` and similar characters in credentials cannot be read
        as URL delimiters.
        """
        # safe="" also escapes "/", which parse.quote leaves untouched by default.
        username = parse.quote(self.credentials.username, safe="")
        password = parse.quote(self.credentials.password, safe="")
        # NOTE(review): the host keeps parse.quote's defaults. SQLAlchemy does not
        # appear to percent-decode the host, so characters such as "\" (named
        # instances) or ":" would reach the driver still encoded — verify.
        host = parse.quote(self.credentials.host)

        return f"mssql+pyodbc://{username}:{password}@{host}/{self.database}?driver={self.driver}&TrustServerCertificate=yes"

    def get_connection(self) -> sqlalchemy.engine.base.Connection:
        """Open and return a new connection from the engine; the caller must close it."""
        return self._engine.connect()

    def execute_query(
        self, query: str, parameters: dict | list | None = None, return_results: bool = False
    ) -> list[sqlalchemy.Row] | None:
        """Execute a textual SQL statement on a fresh connection and commit it.

        Args:
            query: SQL text, wrapped with ``sqlalchemy.text`` (bind names use ``:name``).
            parameters: Bind values — a dict for a single execution, or a list of
                dicts for an executemany. Falsy values run the statement unbound.
            return_results: When True, fetch and return all result rows.

        Returns:
            The fetched rows when ``return_results`` is True, otherwise None.
        """
        statement = sqlalchemy.text(query)
        with self.get_connection() as conn:
            if parameters:
                result = conn.execute(statement, parameters)
            else:
                result = conn.execute(statement)

            # Fetch first, then commit on both paths. Returning before commit()
            # would close the connection with its transaction still open, which
            # SQLAlchemy rolls back — discarding writes from statements that also
            # return rows (e.g. INSERT ... OUTPUT or stored procedures).
            rows = result.fetchall() if return_results else None
            conn.commit()

        return rows
