diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 7a002faebb..3a3b0f9c04 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -813,6 +813,8 @@ class DatabricksConnectionConfig(ConnectionConfig): DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 + shared_connection: t.ClassVar[bool] = True + _concurrent_tasks_validator = concurrent_tasks_validator _http_headers_validator = http_headers_validator diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 2ff95525f7..4a77aea43b 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1422,6 +1422,35 @@ def test_databricks(make_config): ) +def test_databricks_shared_connection(make_config): + """Databricks should use a shared connection pool to prevent OAuth CSRF races. + + When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per + thread. For U2M OAuth, each thread triggers its own browser-based OAuth flow; + these race on the CSRF state parameter and cause MismatchingStateError. + + Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be + used instead: a single connection is created (behind a lock) and each thread + gets its own cursor, so only one OAuth flow is ever initiated. + + See: https://github.com/tobymao/sqlmesh/issues/5646 + """ + from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + access_token="test-token", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is True + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) + + def test_engine_import_validator(): with pytest.raises( ConfigError,