11 examples
Deadlock
Processes permanently stuck waiting on each other, halting execution.
What is a deadlock?
A deadlock occurs when two or more threads or processes each hold a resource (such as a mutex or a locked database record) while waiting for a resource held by another, so that the waits form a cycle. Because no participant can proceed until another releases what it holds, the threads or transactions involved stall indefinitely unless a timeout or a deadlock detector intervenes. In concurrent programming, deadlocks are most often caused by inconsistent lock ordering, locks held longer than necessary, or poorly structured transactions. Deadlocks degrade system performance and reliability, and can lead to application downtime and, in some cases, data integrity issues.
How to fix deadlocks
To fix deadlocks, acquire resources in the same order in every thread or transaction, so that a cycle of waits cannot form. Keep lock scope and duration small: hold a lock only as long as necessary, which also reduces contention. In database systems, choose transaction isolation levels deliberately and set lock or statement timeouts so that a stalled transaction fails fast instead of waiting forever. Use the deadlock detection and recovery mechanisms that databases and some concurrency libraries provide, and be prepared to retry the operation that was aborted to break the cycle. Regular concurrency testing, log analysis, and specialized tooling help identify potential deadlocks early.
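As a minimal sketch of the consistent-ordering advice above, the following Python example locks two resources in a fixed global order no matter which direction a transfer goes. The `Account` class, the `transfer` function, and ordering by `name` are assumptions made for illustration; any stable, total order over the locks would serve.

```python
import threading


class Account:
    """Hypothetical resource guarded by its own lock (illustration only)."""

    def __init__(self, name: str, balance: int) -> None:
        self.name = name
        self.balance = balance
        self.lock = threading.Lock()


def transfer(src: Account, dst: Account, amount: int) -> None:
    if src is dst:
        return  # nothing to do, and locking the same Lock twice would block
    # Always take the locks in one global order (here: by name), whatever
    # the direction of the transfer, so two threads can never wait in a cycle.
    first, second = sorted((src, dst), key=lambda account: account.name)
    with first.lock:
        with second.lock:
            src.balance -= amount
            dst.balance += amount


def run_demo() -> tuple:
    a, b = Account("a", 1000), Account("b", 1000)

    def worker(src: Account, dst: Account) -> None:
        for _ in range(500):
            transfer(src, dst, 1)

    # The two threads transfer in opposite directions, which is the classic
    # setup for a lock-order deadlock when each locks "its own" side first.
    threads = [
        threading.Thread(target=worker, args=(a, b)),
        threading.Thread(target=worker, args=(b, a)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join(timeout=10)
    return a.balance, b.balance
```

Calling `run_demo()` runs both directions concurrently. Without the `sorted` step, each thread could take its source lock first and then wait on the other's.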
diff block
}
// Add delay between runs if too many sessions
- if WS_SESSIONS.len() > 1000 {
+ let r = WS_SESSIONS.read().await;
+ if r.len() > 1000 {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
+ drop(r);
}
});
}
async fn cleanup_expired_sessions() {
- let expired: Vec<String> = WS_SESSIONS
- .iter()
- .filter(|entry| entry.value().is_expired())
- .map(|entry| entry.key().clone())
- .collect();
+ // get expired sessions
+ let r = WS_SESSIONS.read().await;
+ let session_ids: Vec<String> = r.keys().cloned().collect();
+ let mut expired = Vec::new();
+
+ for session_id in session_ids {
+ if let Some(session) = r.get(&session_id) {
+ let session_lock = session.lock().await;
+ if session_lock.is_expired() {
+ expired.push(session_id);
+ }
+ drop(session_lock);
+ }
+ }
greptile
logic: Holding the read lock on WS_SESSIONS while awaiting each session's lock could lead to a deadlock: if another task holds a session lock and then waits for the write lock, neither side can proceed. Consider collecting all sessions first, dropping the read lock, then checking expiration.
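The suggestion in this comment (snapshot first, release the outer lock, then inspect each session) might look roughly like the sketch below. It is written in Python with asyncio rather than Rust, and uses a plain `asyncio.Lock` as a stand-in for the read/write lock. `Session`, `find_expired`, and `registry_lock` are hypothetical names, not taken from the reviewed code.

```python
import asyncio


class Session:
    """Hypothetical session with its own lock and an expiry flag."""

    def __init__(self, expired: bool) -> None:
        self.lock = asyncio.Lock()
        self.expired = expired


async def find_expired(registry_lock: asyncio.Lock, sessions: dict) -> list:
    # Step 1: hold the registry lock only long enough to copy references.
    async with registry_lock:
        snapshot = list(sessions.items())

    # Step 2: the registry lock is released here, so a task that holds a
    # session lock and wants the registry lock is no longer blocked by us.
    expired = []
    for session_id, session in snapshot:
        async with session.lock:
            if session.expired:
                expired.append(session_id)
    return expired


def demo() -> list:
    async def main() -> list:
        registry_lock = asyncio.Lock()
        sessions = {
            "s1": Session(expired=True),
            "s2": Session(expired=False),
            "s3": Session(expired=True),
        }
        return await find_expired(registry_lock, sessions)

    return asyncio.run(main())
```

The trade-off is that the snapshot can be slightly stale: a session may be removed or refreshed between the two steps, so whatever acts on the result should re-check before deleting.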
diff block
+import { ErrCode } from "@/errors/errCodes.js";
+import { QueueManager } from "@/queue/QueueManager.js";
+import RecaseError from "@/utils/errorUtils.js";
+
+export const handleAttachRaceCondition = async ({
+ req,
+ res,
+}: {
+ req: any;
+ res: any;
+}) => {
+ const redisConn = await QueueManager.getConnection({ useBackup: false });
+ const customerId = req.body.customer_id;
+ const orgId = req.orgId;
+ const env = req.env;
+ try {
+ const lockKey = `attach_${customerId}_${orgId}_${env}`;
+ const existingLock = await redisConn.get(lockKey);
+ if (existingLock) {
+ throw new RecaseError({
+ message: `Attach already running for customer ${customerId}, try again in a few seconds`,
+ code: ErrCode.InvalidRequest,
+ statusCode: 400,
+ });
+ }
+ // Create lock with 5 second timeout
+ await redisConn.set(lockKey, "1", "PX", 5000, "NX");
+
+ let originalJson = res.json;
+ res.json = async function (body: any) {
+ if (lockKey) {
+ await clearLock({ lockKey, logger: req.logtail });
+ }
+ originalJson.call(this, body);
+ };
+
+ // return lockKey;
+ } catch (error) {
+ if (error instanceof RecaseError) {
+ throw error;
+ }
+
+ req.logtail.warn("❗️❗️ Error acquiring lock");
+ req.logtail.warn(error);
+ return null;
+ }
+};
+
+export const clearLock = async ({
+ lockKey,
+ logger,
+}: {
+ lockKey: string;
+ logger: any;
+}) => {
+ try {
+ const redisConn = await QueueManager.getConnection({ useBackup: false });
+ await redisConn.del(lockKey);
+ } catch (error) {
+ logger.warn("❗️❗️ Error clearing lock");
+ logger.warn(error);
+ }
greptile
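logic: Failures to clear the lock should probably be escalated: a lock that is not released keeps rejecting attach requests for this customer until its 5-second expiry, which behaves like a short-lived deadlock. Consider throwing after logging.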
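One way to read this advice is sketched below: acquire the lock with a single set-if-absent call that carries an expiry, and raise when the release does not succeed instead of only logging. `InMemoryLockStore` is a hypothetical in-process stand-in for the key-value store, and `acquire`, `release`, and `LockReleaseError` are illustrative names rather than part of the reviewed code or of any Redis client.

```python
import time
import uuid
from typing import Optional


class InMemoryLockStore:
    """Hypothetical stand-in for a store with set-if-absent and expiry."""

    def __init__(self) -> None:
        self._data = {}  # key -> (owner token, monotonic expiry time)

    def set_if_absent(self, key: str, value: str, ttl_seconds: float) -> bool:
        now = time.monotonic()
        current = self._data.get(key)
        if current is not None and current[1] > now:
            return False  # another holder has an unexpired lock
        self._data[key] = (value, now + ttl_seconds)
        return True

    def delete_if_value(self, key: str, value: str) -> bool:
        current = self._data.get(key)
        if current is not None and current[0] == value:
            del self._data[key]
            return True
        return False


class LockReleaseError(RuntimeError):
    """Raised when a lock could not be released by its owner."""


def acquire(store: InMemoryLockStore, key: str, ttl_seconds: float = 5.0) -> Optional[str]:
    # One atomic check-and-set; the token identifies this holder.
    token = uuid.uuid4().hex
    return token if store.set_if_absent(key, token, ttl_seconds) else None


def release(store: InMemoryLockStore, key: str, token: str) -> None:
    try:
        released = store.delete_if_value(key, token)
    except Exception as exc:
        # Escalate instead of swallowing: callers decide how to react.
        raise LockReleaseError(f"could not release {key}") from exc
    if not released:
        raise LockReleaseError(f"{key} was not held by this token")
```

The expiry still bounds how long a stuck lock can block others. Raising on a failed release makes the problem visible to the caller instead of leaving it in the logs.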
diff block
+// Copyright 2025 OpenObserve Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+use std::collections::{HashMap, HashSet};
+
+use config::{RwAHashMap, ider};
+
+use super::{
+ error::*,
+ handler::{ClientId, QuerierName, SessionId, TraceId},
+};
+
+#[derive(Debug, Default)]
+pub struct SessionManager {
+ sessions: RwAHashMap<ClientId, SessionInfo>,
+ mapped_queriers: RwAHashMap<QuerierName, Vec<TraceId>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct SessionInfo {
+ pub session_id: SessionId,
+ pub querier_mappings: HashMap<TraceId, QuerierName>,
+ pub created_at: chrono::DateTime<chrono::Utc>,
+ pub last_active: chrono::DateTime<chrono::Utc>,
+}
+
+impl SessionManager {
+ pub async fn register_client(&self, client_id: &ClientId) {
+ if self.sessions.read().await.get(client_id).is_some() {
+ self.update_session_activity(client_id).await;
+ return;
+ }
+
+ let now = chrono::Utc::now();
+ let session_info = SessionInfo {
+ session_id: ider::uuid(),
+ querier_mappings: HashMap::default(),
+ created_at: now,
+ last_active: now,
+ };
+
+ let mut write_guard = self.sessions.write().await;
+ if !write_guard.contains_key(client_id) {
+ write_guard.insert(client_id.clone(), session_info.clone());
+ return;
+ }
+ }
+
+ pub async fn update_session_activity(&self, client_id: &ClientId) {
+ let mut write_guard = self.sessions.write().await;
+ if let Some(session_info) = write_guard.get_mut(client_id) {
+ session_info.last_active = chrono::Utc::now();
+ }
+ }
+
+ pub async fn unregister_client(&self, client_id: &ClientId) {
+ if let Some(session_info) = self.sessions.write().await.remove(client_id) {
+ let mut mapped_querier_write = self.mapped_queriers.write().await;
+
+ for (trace_id, querier_name) in session_info.querier_mappings {
+ if let Some(trace_ids) = mapped_querier_write.get_mut(&querier_name) {
+ trace_ids.retain(|tid| tid != &trace_id);
+ }
+ }
+ }
+ }
+
+ pub async fn remove_querier_connection(&self, querier_name: &QuerierName) {
+ let client_ids = {
+ let (mapped_read, sessions_read) =
+ tokio::join!(self.mapped_queriers.read(), self.sessions.read());
+
+ match mapped_read.get(querier_name) {
+ Some(_) => sessions_read.keys().cloned().collect::<Vec<_>>(),
+ None => return,
+ }
+ };
+
+ // Remove from mapped_querier
+ let trace_ids = self
+ .mapped_queriers
+ .write()
+ .await
+ .remove(querier_name)
+ .map(|ids| ids.into_iter().collect::<HashSet<_>>())
+ .unwrap(); // existence validated
+
+ // Batch update sessions
+ let mut session_write = self.sessions.write().await;
+ for client_id in client_ids {
+ if let Some(session_info) = session_write.get_mut(&client_id) {
+ session_info
+ .querier_mappings
+ .retain(|tid, _| !trace_ids.contains(tid));
+ }
+ }
+ }
+
+ pub async fn set_querier_for_trace(
+ &self,
+ client_id: &ClientId,
+ trace_id: &TraceId,
+ querier_name: &QuerierName,
+ ) -> WsResult<()> {
+ // sessions
+ // let mut write_guard = self.sessions.write().await;
+ self.sessions
+ .write()
+ .await
+ .get_mut(client_id)
+ .ok_or(WsError::SessionNotFound(format!(
+ "[WS::SessionManager]: client_id {} not found",
+ client_id
+ )))?
+ .querier_mappings
+ .insert(trace_id.clone(), querier_name.clone());
+
+ // mapped_queriers
+ self.mapped_queriers
+ .write()
+ .await
+ .entry(querier_name.clone())
+ .or_insert_with(|| Vec::new())
+ .push(trace_id.clone());
+ Ok(())
+ }
greptile
style: Two separate write locks could deadlock if another thread acquires them in the reverse order. Consider combining the updates into a single atomic operation or establishing a consistent lock ordering.
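The "single atomic operation" option could be sketched as follows: both maps sit behind one lock, so there is no second lock to acquire in the wrong order and no window in which only one map has been updated. This is a simplified Python illustration, not the reviewed Rust code; `SessionIndex` and its attribute names are assumptions.

```python
import asyncio


class SessionIndex:
    """Hypothetical registry in which one lock guards both maps."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self.querier_by_client = {}  # client_id -> {trace_id: querier_name}
        self.traces_by_querier = {}  # querier_name -> [trace_id, ...]

    async def set_querier_for_trace(
        self, client_id: str, trace_id: str, querier_name: str
    ) -> None:
        # Both updates happen under the same lock, so other tasks see either
        # neither change or both, and lock order cannot be inverted.
        async with self._lock:
            if client_id not in self.querier_by_client:
                raise KeyError(f"client_id {client_id} not found")
            self.querier_by_client[client_id][trace_id] = querier_name
            self.traces_by_querier.setdefault(querier_name, []).append(trace_id)


def demo() -> tuple:
    async def main() -> tuple:
        index = SessionIndex()
        index.querier_by_client["client-1"] = {}
        await index.set_querier_for_trace("client-1", "trace-1", "querier-a")
        try:
            await index.set_querier_for_trace("missing", "trace-2", "querier-a")
        except KeyError:
            pass  # an unknown client leaves both maps untouched
        return index.querier_by_client, index.traces_by_querier

    return asyncio.run(main())
```

A single coarse lock trades some concurrency for simplicity. If the two maps need independent locks, every code path that holds both at once has to take them in the same documented order.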
diff block
-import type { Channel } from 'storybook/internal/channels';
-import {
- TESTING_MODULE_CANCEL_TEST_RUN_REQUEST,
- TESTING_MODULE_PROGRESS_REPORT,
- TESTING_MODULE_RUN_REQUEST,
- type TestingModuleCancelTestRunRequestPayload,
- type TestingModuleProgressReportPayload,
- type TestingModuleRunRequestPayload,
-} from 'storybook/internal/core-events';
+import type { TestResult, TestState } from 'vitest/dist/node.js';
+
import type { experimental_UniversalStore } from 'storybook/internal/core-server';
+import type {
+ StatusStoreByTypeId,
+ StatusValue,
+ TestProviderStoreById,
+} from 'storybook/internal/types';
-import { isEqual } from 'es-toolkit';
+import { throttle } from 'es-toolkit';
+import type { Report } from 'storybook/preview-api';
-import { type StoreState, TEST_PROVIDER_ID } from '../constants';
+import { STATUS_TYPE_ID_A11Y, STATUS_TYPE_ID_COMPONENT_TEST, storeOptions } from '../constants';
+import type { RunTrigger, StoreEvent, StoreState, TriggerRunEvent, VitestError } from '../types';
+import { errorToErrorLike } from '../utils';
import { VitestManager } from './vitest-manager';
+export type TestManagerOptions = {
+ store: experimental_UniversalStore<StoreState, StoreEvent>;
+ componentTestStatusStore: StatusStoreByTypeId;
+ a11yStatusStore: StatusStoreByTypeId;
+ testProviderStore: TestProviderStoreById;
+ onError?: (message: string, error: Error) => void;
+ onReady?: () => void;
+};
+
+const testStateToStatusValueMap: Record<TestState | 'warning', StatusValue> = {
+ pending: 'status-value:pending',
+ passed: 'status-value:success',
+ warning: 'status-value:warning',
+ failed: 'status-value:error',
+ skipped: 'status-value:unknown',
+};
+
export class TestManager {
- vitestManager: VitestManager;
-
- selectedStoryCountForLastRun = 0;
-
- constructor(
- private channel: Channel,
- public store: experimental_UniversalStore<StoreState>,
- private options: {
- onError?: (message: string, error: Error) => void;
- onReady?: () => void;
- } = {}
- ) {
- this.vitestManager = new VitestManager(this);
+ public store: TestManagerOptions['store'];
- this.channel.on(TESTING_MODULE_RUN_REQUEST, this.handleRunRequest.bind(this));
- this.channel.on(TESTING_MODULE_CANCEL_TEST_RUN_REQUEST, this.handleCancelRequest.bind(this));
+ public vitestManager: VitestManager;
- this.store.onStateChange((state, previousState) => {
- if (!isEqual(state.config, previousState.config)) {
- this.handleConfigChange(state.config, previousState.config);
- }
- if (state.watching !== previousState.watching) {
- this.handleWatchModeRequest(state.watching);
- }
- });
+ private componentTestStatusStore: TestManagerOptions['componentTestStatusStore'];
- this.vitestManager.startVitest().then(() => options.onReady?.());
- }
+ private a11yStatusStore: TestManagerOptions['a11yStatusStore'];
- async handleConfigChange(config: StoreState['config'], previousConfig: StoreState['config']) {
- process.env.VITEST_STORYBOOK_CONFIG = JSON.stringify(config);
+ private testProviderStore: TestManagerOptions['testProviderStore'];
- if (config.coverage !== previousConfig.coverage) {
- try {
- await this.vitestManager.restartVitest({
- coverage: config.coverage,
- });
- } catch (e) {
- this.reportFatalError('Failed to change coverage configuration', e);
- }
- }
+ private onReady?: TestManagerOptions['onReady'];
+
+ private batchedTestCaseResults: {
+ storyId: string;
+ testResult: TestResult;
+ reports?: Report[];
+ }[] = [];
+
+ constructor(options: TestManagerOptions) {
+ this.store = options.store;
+ this.componentTestStatusStore = options.componentTestStatusStore;
+ this.a11yStatusStore = options.a11yStatusStore;
+ this.testProviderStore = options.testProviderStore;
+ this.onReady = options.onReady;
+
+ this.vitestManager = new VitestManager(this);
+
+ this.store.subscribe('TRIGGER_RUN', this.handleTriggerRunEvent.bind(this));
+ this.store.subscribe('CANCEL_RUN', this.handleCancelEvent.bind(this));
+
+ this.store
+ .untilReady()
+ .then(() =>
+ this.vitestManager.startVitest({ coverage: this.store.getState().config.coverage })
+ )
+ .then(() => this.onReady?.())
+ .catch((e) => {
+ this.reportFatalError('Failed to start Vitest', e);
+ });
}
- async handleWatchModeRequest(watching: boolean) {
- const coverage = this.store.getState().config.coverage ?? false;
-
- if (coverage) {
- try {
- if (watching) {
- // if watch mode is toggled on and coverage is already enabled, restart vitest without coverage to automatically disable it
- await this.vitestManager.restartVitest({ coverage: false });
- } else {
- // if watch mode is toggled off and coverage is already enabled, restart vitest with coverage to automatically re-enable it
- await this.vitestManager.restartVitest({ coverage });
+ async handleTriggerRunEvent(event: TriggerRunEvent) {
+ await this.runTestsWithState({
+ storyIds: event.payload.storyIds,
+ triggeredBy: event.payload.triggeredBy,
+ callback: async () => {
+ try {
+ await this.vitestManager.vitestRestartPromise;
+ await this.vitestManager.runTests(event.payload);
+ } catch (err) {
+ this.reportFatalError('Failed to run tests', err);
+ throw err;
}
- } catch (e) {
- this.reportFatalError('Failed to change watch mode while coverage was enabled', e);
- }
- }
+ },
+ });
}
- async handleRunRequest(payload: TestingModuleRunRequestPayload) {
+ async handleCancelEvent() {
try {
- if (payload.providerId !== TEST_PROVIDER_ID) {
- return;
- }
-
- const state = this.store.getState();
-
- /*
- If we're only running a subset of stories, we have to temporarily disable coverage,
- as a coverage report for a subset of stories is not useful.
- */
- const temporarilyDisableCoverage =
- state.config.coverage && !state.watching && (payload.storyIds ?? []).length > 0;
- if (temporarilyDisableCoverage) {
- await this.vitestManager.restartVitest({
- coverage: false,
- });
- } else {
- await this.vitestManager.vitestRestartPromise;
- }
+ this.store.setState((s) => ({
+ ...s,
+ cancelling: true,
+ }));
+ await this.vitestManager.cancelCurrentRun();
+ } catch (err) {
+ this.reportFatalError('Failed to cancel tests', err);
+ } finally {
+ this.store.setState((s) => ({
+ ...s,
+ cancelling: false,
+ }));
+ }
+ }
- this.selectedStoryCountForLastRun = payload.storyIds?.length ?? 0;
+ async runTestsWithState({
+ storyIds,
+ triggeredBy,
+ callback,
+ }: {
+ storyIds?: string[];
+ triggeredBy: RunTrigger;
+ callback: () => Promise<void>;
+ }) {
+ this.componentTestStatusStore.unset(storyIds);
+ this.a11yStatusStore.unset(storyIds);
- await this.vitestManager.runTests(payload);
+ this.store.setState((s) => ({
+ ...s,
+ currentRun: {
+ ...storeOptions.initialState.currentRun,
+ triggeredBy,
+ startedAt: Date.now(),
+ storyIds: storyIds,
+ config: s.config,
+ },
+ }));
+ // set the config at the start of a test run,
+ // so that changing the config during the test run does not affect the currently running test run
+ process.env.VITEST_STORYBOOK_CONFIG = JSON.stringify(this.store.getState().config);
- if (temporarilyDisableCoverage) {
- // Re-enable coverage if it was temporarily disabled because of a subset of stories was run
- await this.vitestManager.restartVitest({ coverage: state?.config.coverage });
+ await this.testProviderStore.runWithState(async () => {
+ await callback();
+ this.store.send({
+ type: 'TEST_RUN_COMPLETED',
+ payload: this.store.getState().currentRun,
+ });
+ if (this.store.getState().currentRun.unhandledErrors.length > 0) {
+ throw new Error('Tests completed but there are unhandled errors');
}
- } catch (e) {
- this.reportFatalError('Failed to run tests', e);
- }
+ });
}
- async handleCancelRequest(payload: TestingModuleCancelTestRunRequestPayload) {
- try {
- if (payload.providerId !== TEST_PROVIDER_ID) {
- return;
- }
+ onTestModuleCollected(collectedTestCount: number) {
+ this.store.setState((s) => ({
+ ...s,
+ currentRun: {
+ ...s.currentRun,
+ totalTestCount: (s.currentRun.totalTestCount ?? 0) + collectedTestCount,
+ },
+ }));
+ }
- await this.vitestManager.cancelCurrentRun();
- } catch (e) {
- this.reportFatalError('Failed to cancel tests', e);
+ onTestCaseResult(result: { storyId?: string; testResult: TestResult; reports?: Report[] }) {
+ const { storyId, testResult, reports } = result;
+ if (!storyId) {
+ return;
}
+
+ this.batchedTestCaseResults.push({ storyId, testResult, reports });
+ this.throttledFlushTestCaseResults();
}
- async sendProgressReport(payload: TestingModuleProgressReportPayload) {
- this.channel.emit(TESTING_MODULE_PROGRESS_REPORT, {
- ...payload,
- details: { ...payload.details, selectedStoryCount: this.selectedStoryCountForLastRun },
+ /**
+ * Throttled function to process batched test case results.
+ *
+ * This function:
+ *
+ * 1. Takes all batched test case results and clears the batch
+ * 2. Updates the store state with new test counts (component tests and a11y tests)
+ * 3. Adjusts the totalTestCount if more tests were run than initially anticipated
+ * 4. Creates status objects for component tests and updates the component test status store
+ * 5. Creates status objects for a11y tests (if any) and updates the a11y status store
+ *
+ * The throttling (500ms) is necessary as the channel would otherwise get overwhelmed with events,
+ * eventually causing the manager and dev server to lose connection.
+ */
+ throttledFlushTestCaseResults = throttle(() => {
+ const testCaseResultsToFlush = this.batchedTestCaseResults;
+ this.batchedTestCaseResults = [];
+
+ this.store.setState((s) => {
+ let { success: ctSuccess, error: ctError } = s.currentRun.componentTestCount;
+ let { success: a11ySuccess, warning: a11yWarning, error: a11yError } = s.currentRun.a11yCount;
+ testCaseResultsToFlush.forEach(({ testResult, reports }) => {
+ if (testResult.state === 'passed') {
+ ctSuccess++;
+ } else if (testResult.state === 'failed') {
+ ctError++;
+ }
+ reports
+ ?.filter((r) => r.type === 'a11y')
+ .forEach((report) => {
+ if (report.status === 'passed') {
+ a11ySuccess++;
+ } else if (report.status === 'warning') {
+ a11yWarning++;
+ } else if (report.status === 'failed') {
+ a11yError++;
+ }
+ });
+ });
+ const finishedTestCount = ctSuccess + ctError;
+
+ return {
+ ...s,
+ currentRun: {
+ ...s.currentRun,
+ componentTestCount: { success: ctSuccess, error: ctError },
+ a11yCount: { success: a11ySuccess, warning: a11yWarning, error: a11yError },
+ // in some cases successes and errors can exceed the anticipated totalTestCount
+ // e.g. when testing more tests than the stories we know about upfront
+ // in those cases, we set the totalTestCount to the sum of successes and errors
+ totalTestCount:
+ finishedTestCount > (s.currentRun.totalTestCount ?? 0)
+ ? finishedTestCount
+ : s.currentRun.totalTestCount,
+ },
+ };
});
- const status = 'status' in payload ? payload.status : undefined;
- const progress = 'progress' in payload ? payload.progress : undefined;
- if (
- ((status === 'success' || status === 'cancelled') && progress?.finishedAt) ||
- status === 'failed'
- ) {
- // reset the count when a test run is fully finished
- this.selectedStoryCountForLastRun = 0;
+ const componentTestStatuses = testCaseResultsToFlush.map(({ storyId, testResult }) => ({
+ storyId,
+ typeId: STATUS_TYPE_ID_COMPONENT_TEST,
+ value: testStateToStatusValueMap[testResult.state],
+ title: 'Component tests',
+ description: testResult.errors?.map((error) => error.stack || error.message).join('\n') ?? '',
+ sidebarContextMenu: false,
+ }));
+
+ this.componentTestStatusStore.set(componentTestStatuses);
+
+ const a11yStatuses = testCaseResultsToFlush
+ .flatMap(({ storyId, reports }) =>
+ reports
+ ?.filter((r) => r.type === 'a11y')
+ .map((a11yReport) => ({
+ storyId,
+ typeId: STATUS_TYPE_ID_A11Y,
+ value: testStateToStatusValueMap[a11yReport.status],
+ title: 'Accessibility tests',
+ description: '',
+ sidebarContextMenu: false,
+ }))
+ )
+ .filter((a11yStatus) => a11yStatus !== undefined);
+
+ if (a11yStatuses.length > 0) {
+ this.a11yStatusStore.set(a11yStatuses);
}
+ }, 500);
+
+ onTestRunEnd(endResult: { totalTestCount: number; unhandledErrors: VitestError[] }) {
+ this.store.setState((s) => ({
+ ...s,
+ currentRun: {
+ ...s.currentRun,
+ // when the test run is finished, we can set the totalTestCount to the actual number of tests run
+ // this number can be lower than the total number of tests we anticipated upfront
+ // e.g. when some tests were skipped without us knowing about it upfront
+ totalTestCount: endResult.totalTestCount,
+ unhandledErrors: endResult.unhandledErrors,
+ finishedAt: Date.now(),
+ },
+ }));
+ }
+
+ onCoverageCollected(coverageSummary: StoreState['currentRun']['coverageSummary']) {
+ this.store.setState((s) => ({
+ ...s,
+ currentRun: { ...s.currentRun, coverageSummary },
+ }));
}
async reportFatalError(message: string, error: Error | any) {
- this.options.onError?.(message, error);
+ await this.store.untilReady();
+ this.store.send({
+ type: 'FATAL_ERROR',
+ payload: {
+ message,
+ error: errorToErrorLike(error),
+ },
+ });
}
greptile
logic: reportFatalError awaits store.untilReady() but is called from catch blocks, so error reporting could hang indefinitely (an effective deadlock) if the store never becomes ready.
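A common guard against this kind of unbounded wait is to put a timeout on the readiness check and fall back to another reporting channel when it elapses. The sketch below shows the idea in Python with `asyncio.wait_for`. The `ready` event, the `sent` list, and the fallback message format are hypothetical stand-ins for the store and its error channel.

```python
import asyncio


async def report_fatal_error(
    ready: asyncio.Event, sent: list, message: str, timeout: float = 0.05
) -> bool:
    """Report through the store if it becomes ready in time, else fall back."""
    try:
        # Bound the wait so an error path can never block forever.
        await asyncio.wait_for(ready.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        # Assumed fallback: record the error some other way (e.g. a log).
        sent.append(f"fallback: {message}")
        return False
    sent.append(f"store: {message}")
    return True


def demo() -> list:
    async def main() -> list:
        sent = []

        never_ready = asyncio.Event()  # simulates a store that never starts
        await report_fatal_error(never_ready, sent, "Failed to start Vitest")

        ready = asyncio.Event()
        ready.set()  # simulates a store that is already ready
        await report_fatal_error(ready, sent, "Failed to run tests")
        return sent

    return asyncio.run(main())
```

The timeout value is arbitrary here. In practice it would be chosen to exceed normal startup time while still surfacing the error promptly.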
diff block
+import asyncio
+import dataclasses
+import datetime as dt
+import json
+import typing
+
+import temporalio.activity
+import temporalio.common
+import temporalio.workflow
+from asgiref.sync import sync_to_async
+
+from posthog.temporal.common.base import PostHogWorkflow
+from posthog.temporal.common.heartbeat import Heartbeater
+from posthog.temporal.common.logger import get_internal_logger
+from posthog.temporal.session_recordings.queries import get_sampled_session_ids
+from posthog.session_recordings.models.session_recording import SessionRecording
+from posthog.models import Team
+from posthog.temporal.session_recordings.session_comparer import (
+ get_url_from_event,
+ add_url,
+ count_events_per_window,
+ group_events_by_type,
+ is_click,
+ is_mouse_activity,
+ is_keypress,
+ is_console_log,
+ get_console_level,
+)
+from posthog.temporal.session_recordings.queries import get_session_metadata
+from posthog.temporal.session_recordings.snapshot_utils import fetch_v1_snapshots, fetch_v2_snapshots
+
+
+@dataclasses.dataclass(frozen=True)
+class CompareSampledRecordingEventsActivityInputs:
+ """Inputs for the recording events comparison activity."""
+
+ started_after: str = dataclasses.field()
+ started_before: str = dataclasses.field()
+ sample_size: int = dataclasses.field(default=100)
+
+ @property
+ def properties_to_log(self) -> dict[str, typing.Any]:
+ return {
+ "started_after": self.started_after,
+ "started_before": self.started_before,
+ "sample_size": self.sample_size,
+ }
+
+
+@temporalio.activity.defn
+async def compare_sampled_recording_events_activity(inputs: CompareSampledRecordingEventsActivityInputs) -> None:
+ """Compare recording events between v1 and v2 storage for a sample of sessions."""
+ logger = get_internal_logger()
+ start_time = dt.datetime.now()
+
+ await logger.ainfo(
+ "Starting sampled events comparison activity",
+ started_after=inputs.started_after,
+ started_before=inputs.started_before,
+ sample_size=inputs.sample_size,
+ )
+
+ async with Heartbeater():
+ started_after = dt.datetime.fromisoformat(inputs.started_after)
+ started_before = dt.datetime.fromisoformat(inputs.started_before)
+
+ # Get sample of session IDs
+ session_ids = await asyncio.to_thread(
+ get_sampled_session_ids,
+ started_after,
+ started_before,
+ inputs.sample_size,
+ )
+
+ for session_id, team_id in session_ids:
+ await logger.ainfo(
+ "Processing session",
+ session_id=session_id,
+ team_id=team_id,
+ )
+
+ team = await sync_to_async(Team.objects.get)(id=team_id)
+ recording = await sync_to_async(SessionRecording.get_or_build)(session_id=session_id, team=team)
+
+ # Get v1 and v2 snapshots using the shared utility functions
+ v1_snapshots = await asyncio.to_thread(fetch_v1_snapshots, recording)
+ v2_snapshots = await asyncio.to_thread(fetch_v2_snapshots, recording)
+
+ # Convert snapshots to dictionaries for counting duplicates
+ v1_events: dict[str, int] = {}
+ v2_events: dict[str, int] = {}
+
+ for s in v1_snapshots:
+ event_key = json.dumps((s["window_id"], s["data"]), sort_keys=True)
+ v1_events[event_key] = v1_events.get(event_key, 0) + 1
+
+ for s in v2_snapshots:
+ event_key = json.dumps((s["window_id"], s["data"]), sort_keys=True)
+ v2_events[event_key] = v2_events.get(event_key, 0) + 1
+
+ # Find events in both versions with their counts
+ all_keys = set(v1_events.keys()) | set(v2_events.keys())
+ common_events = {
+ k: (v1_events.get(k, 0), v2_events.get(k, 0)) for k in all_keys if k in v1_events and k in v2_events
+ }
+ only_in_v1 = {k: v1_events[k] for k in v1_events.keys() - v2_events.keys()}
+ only_in_v2 = {k: v2_events[k] for k in v2_events.keys() - v1_events.keys()}
+
+ # Get metadata counts
+ v1_metadata = get_session_metadata(team.pk, recording.session_id, "session_replay_events")
+ v2_metadata = get_session_metadata(team.pk, recording.session_id, "session_replay_events_v2_test")
+
+ # Track URLs for both versions
+ v1_urls: set[str] = set()
+ v1_first_url: str | None = None
+ v2_urls: set[str] = set()
+ v2_first_url: str | None = None
+
+ # Count events by type in v1
+ v1_click_count = 0
+ v1_mouse_activity_count = 0
+ v1_keypress_count = 0
+ v1_console_log_count = 0
+ v1_console_warn_count = 0
+ v1_console_error_count = 0
+
+ for snapshot in v1_snapshots:
+ data = snapshot["data"]
+ if is_click(data):
+ v1_click_count += 1
+ if is_mouse_activity(data):
+ v1_mouse_activity_count += 1
+ if is_keypress(data):
+ v1_keypress_count += 1
+ if is_console_log(data):
+ level = get_console_level(data)
+ if level in [
+ "log",
+ "info",
+ "count",
+ "timeEnd",
+ "trace",
+ "dir",
+ "dirxml",
+ "group",
+ "groupCollapsed",
+ "debug",
+ "timeLog",
+ ]:
+ v1_console_log_count += 1
+ elif level in ["warn", "countReset"]:
+ v1_console_warn_count += 1
+ elif level in ["error", "assert"]:
+ v1_console_error_count += 1
+ else: # default to log level for unknown levels
+ v1_console_log_count += 1
+
+ # Extract URL
+ url = get_url_from_event(data)
+ if url:
+ if v1_first_url is None:
+ v1_first_url = url[:4096] if len(url) > 4096 else url
+ add_url(v1_urls, url)
+
+ # Count events by type in v2
+ v2_click_count = 0
+ v2_mouse_activity_count = 0
+ v2_keypress_count = 0
+ v2_console_log_count = 0
+ v2_console_warn_count = 0
+ v2_console_error_count = 0
+
+ for snapshot in v2_snapshots:
+ data = snapshot["data"]
+ if is_click(data):
+ v2_click_count += 1
+ if is_mouse_activity(data):
+ v2_mouse_activity_count += 1
+ if is_keypress(data):
+ v2_keypress_count += 1
+ if is_console_log(data):
+ level = get_console_level(data)
+ if level in [
+ "log",
+ "info",
+ "count",
+ "timeEnd",
+ "trace",
+ "dir",
+ "dirxml",
+ "group",
+ "groupCollapsed",
+ "debug",
+ "timeLog",
+ ]:
+ v2_console_log_count += 1
+ elif level in ["warn", "countReset"]:
+ v2_console_warn_count += 1
+ elif level in ["error", "assert"]:
+ v2_console_error_count += 1
+ else: # default to log level for unknown levels
+ v2_console_log_count += 1
+
+ # Extract URL
+ url = get_url_from_event(data)
+ if url:
+ if v2_first_url is None:
+ v2_first_url = url[:4096] if len(url) > 4096 else url
+ add_url(v2_urls, url)
+
+ # Compare URLs
+ await logger.ainfo(
+ "URL comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_first_url=v1_first_url,
+ v2_first_url=v2_first_url,
+ first_url_matches=v1_first_url == v2_first_url,
+ v1_url_count=len(v1_urls),
+ v2_url_count=len(v2_urls),
+ urls_in_both=len(v1_urls & v2_urls),
+ only_in_v1=sorted(v1_urls - v2_urls)[:5], # Show up to 5 examples
+ only_in_v2=sorted(v2_urls - v1_urls)[:5], # Show up to 5 examples
+ metadata_comparison={
+ "v1": {
+ "first_url": v1_metadata["first_url"],
+ "all_urls": v1_metadata["all_urls"],
+ "first_url_matches_snapshot": v1_metadata["first_url"] == v1_first_url,
+ "all_urls_match_snapshot": set(v1_metadata["all_urls"]) == v1_urls,
+ },
+ "v2": {
+ "first_url": v2_metadata["first_url"],
+ "all_urls": v2_metadata["all_urls"],
+ "first_url_matches_snapshot": v2_metadata["first_url"] == v2_first_url,
+ "all_urls_match_snapshot": set(v2_metadata["all_urls"]) == v2_urls,
+ },
+ },
+ )
+
+ # Log event counts and differences
+ await logger.ainfo(
+ "Total event count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=len(v1_snapshots),
+ v2_snapshot_count=len(v2_snapshots),
+ v1_metadata_count=v1_metadata["event_count"],
+ v2_metadata_count=v2_metadata["event_count"],
+ snapshot_difference=len(v2_snapshots) - len(v1_snapshots),
+ metadata_difference=v2_metadata["event_count"] - v1_metadata["event_count"],
+ snapshot_vs_metadata_v1_difference=len(v1_snapshots) - v1_metadata["event_count"],
+ snapshot_vs_metadata_v2_difference=len(v2_snapshots) - v2_metadata["event_count"],
+ )
+
+ await logger.ainfo(
+ "Click count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_click_count,
+ v2_snapshot_count=v2_click_count,
+ v1_metadata_count=v1_metadata["click_count"],
+ v2_metadata_count=v2_metadata["click_count"],
+ snapshot_difference=v2_click_count - v1_click_count,
+ metadata_difference=v2_metadata["click_count"] - v1_metadata["click_count"],
+ snapshot_vs_metadata_v1_difference=v1_click_count - v1_metadata["click_count"],
+ snapshot_vs_metadata_v2_difference=v2_click_count - v2_metadata["click_count"],
+ )
+
+ await logger.ainfo(
+ "Mouse activity count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_mouse_activity_count,
+ v2_snapshot_count=v2_mouse_activity_count,
+ v1_metadata_count=v1_metadata["mouse_activity_count"],
+ v2_metadata_count=v2_metadata["mouse_activity_count"],
+ snapshot_difference=v2_mouse_activity_count - v1_mouse_activity_count,
+ metadata_difference=v2_metadata["mouse_activity_count"] - v1_metadata["mouse_activity_count"],
+ snapshot_vs_metadata_v1_difference=v1_mouse_activity_count - v1_metadata["mouse_activity_count"],
+ snapshot_vs_metadata_v2_difference=v2_mouse_activity_count - v2_metadata["mouse_activity_count"],
+ )
+
+ await logger.ainfo(
+ "Keypress count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_keypress_count,
+ v2_snapshot_count=v2_keypress_count,
+ v1_metadata_count=v1_metadata["keypress_count"],
+ v2_metadata_count=v2_metadata["keypress_count"],
+ snapshot_difference=v2_keypress_count - v1_keypress_count,
+ metadata_difference=v2_metadata["keypress_count"] - v1_metadata["keypress_count"],
+ snapshot_vs_metadata_v1_difference=v1_keypress_count - v1_metadata["keypress_count"],
+ snapshot_vs_metadata_v2_difference=v2_keypress_count - v2_metadata["keypress_count"],
+ )
+
+ await logger.ainfo(
+ "Console log count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_console_log_count,
+ v2_snapshot_count=v2_console_log_count,
+ v1_metadata_count=v1_metadata["console_log_count"],
+ v2_metadata_count=v2_metadata["console_log_count"],
+ snapshot_difference=v2_console_log_count - v1_console_log_count,
+ metadata_difference=v2_metadata["console_log_count"] - v1_metadata["console_log_count"],
+ snapshot_vs_metadata_v1_difference=v1_console_log_count - v1_metadata["console_log_count"],
+ snapshot_vs_metadata_v2_difference=v2_console_log_count - v2_metadata["console_log_count"],
+ )
+
+ await logger.ainfo(
+ "Console warn count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_console_warn_count,
+ v2_snapshot_count=v2_console_warn_count,
+ v1_metadata_count=v1_metadata["console_warn_count"],
+ v2_metadata_count=v2_metadata["console_warn_count"],
+ snapshot_difference=v2_console_warn_count - v1_console_warn_count,
+ metadata_difference=v2_metadata["console_warn_count"] - v1_metadata["console_warn_count"],
+ snapshot_vs_metadata_v1_difference=v1_console_warn_count - v1_metadata["console_warn_count"],
+ snapshot_vs_metadata_v2_difference=v2_console_warn_count - v2_metadata["console_warn_count"],
+ )
+
+ await logger.ainfo(
+ "Console error count comparison",
+ session_id=session_id,
+ team_id=team_id,
+ v1_snapshot_count=v1_console_error_count,
+ v2_snapshot_count=v2_console_error_count,
+ v1_metadata_count=v1_metadata["console_error_count"],
+ v2_metadata_count=v2_metadata["console_error_count"],
+ snapshot_difference=v2_console_error_count - v1_console_error_count,
+ metadata_difference=v2_metadata["console_error_count"] - v1_metadata["console_error_count"],
+ snapshot_vs_metadata_v1_difference=v1_console_error_count - v1_metadata["console_error_count"],
+ snapshot_vs_metadata_v2_difference=v2_console_error_count - v2_metadata["console_error_count"],
+ )
+
+ # Log event type comparison
+ await logger.ainfo(
+ "Event type comparison",
+ session_id=session_id,
+ team_id=team_id,
+ common_events_count=sum(min(v1, v2) for v1, v2 in common_events.values()),
+ common_events_by_type=group_events_by_type({k: min(v1, v2) for k, (v1, v2) in common_events.items()}),
+ only_in_v1_count=sum(only_in_v1.values()),
+ only_in_v1_by_type=group_events_by_type(only_in_v1),
+ only_in_v2_count=sum(only_in_v2.values()),
+ only_in_v2_by_type=group_events_by_type(only_in_v2),
+ duplicate_stats={
+ "v1_total_duplicates": sum(count - 1 for count in v1_events.values() if count > 1),
+ "v2_total_duplicates": sum(count - 1 for count in v2_events.values() if count > 1),
+ "events_with_different_counts": {k: (v1, v2) for k, (v1, v2) in common_events.items() if v1 != v2},
+ },
+ )
+
+ # Analyze events per window
+ v1_window_counts = count_events_per_window(v1_events)
+ v2_window_counts = count_events_per_window(v2_events)
+
+ # Find all window IDs
+ all_window_ids = set(v1_window_counts.keys()) | set(v2_window_counts.keys())
+ window_comparison = []
+ # Handle None first, then sort the rest
+ sorted_window_ids = ([None] if None in all_window_ids else []) + sorted(
+ id for id in all_window_ids if id is not None
+ )
+ for window_id in sorted_window_ids:
+ window_comparison.append(
+ {
+ "window_id": window_id,
+ "v1_events": v1_window_counts.get(window_id, 0),
+ "v2_events": v2_window_counts.get(window_id, 0),
+ }
+ )
+
+ await logger.ainfo(
+ "Events per window comparison",
+ session_id=session_id,
+ team_id=team_id,
+ window_counts=window_comparison,
+ total_windows=len(all_window_ids),
+ windows_in_v1=len(v1_window_counts),
+ windows_in_v2=len(v2_window_counts),
+ windows_in_both=len(set(v1_window_counts.keys()) & set(v2_window_counts.keys())),
+ )
+
+ # Check for differences in metadata vs snapshots
+ metadata_differences = any(
+ [
+ v1_metadata["click_count"] != v1_click_count,
+ v1_metadata["mouse_activity_count"] != v1_mouse_activity_count,
+ v1_metadata["keypress_count"] != v1_keypress_count,
+ v1_metadata["console_log_count"] != v1_console_log_count,
+ v1_metadata["console_warn_count"] != v1_console_warn_count,
+ v1_metadata["console_error_count"] != v1_console_error_count,
+ v2_metadata["click_count"] != v2_click_count,
+ v2_metadata["mouse_activity_count"] != v2_mouse_activity_count,
+ v2_metadata["keypress_count"] != v2_keypress_count,
+ v2_metadata["console_log_count"] != v2_console_log_count,
+ v2_metadata["console_warn_count"] != v2_console_warn_count,
+ v2_metadata["console_error_count"] != v2_console_error_count,
+ ]
+ )
+
+ # Check if sessions differ in any way
+ sessions_differ = any(
+ [
+ len(v1_snapshots) != len(v2_snapshots),
+ v1_click_count != v2_click_count,
+ v1_mouse_activity_count != v2_mouse_activity_count,
+ v1_keypress_count != v2_keypress_count,
+ v1_console_log_count != v2_console_log_count,
+ v1_console_warn_count != v2_console_warn_count,
+ v1_console_error_count != v2_console_error_count,
+ v1_urls != v2_urls,
+ v1_first_url != v2_first_url,
+ bool(only_in_v1),
+ bool(only_in_v2),
+ ]
+ )
+
+ # Log session summary
+ await logger.ainfo(
+ "Session comparison summary",
+ session_id=session_id,
+ team_id=team_id,
+ sessions_differ=sessions_differ,
+ metadata_snapshot_differences=metadata_differences,
+ v1_snapshot_count=len(v1_snapshots),
+ v2_snapshot_count=len(v2_snapshots),
+ )
+
+ end_time = dt.datetime.now()
+ duration = (end_time - start_time).total_seconds()
+
+ # Log activity summary
+ await logger.ainfo(
+ "Completed sampled events comparison activity",
+ duration_seconds=duration,
+ sessions_processed=len(session_ids),
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class CompareSampledRecordingEventsWorkflowInputs:
+ """Inputs for the recording events comparison workflow."""
+
+ started_after: str = dataclasses.field()
+ started_before: str = dataclasses.field()
+ window_seconds: int = dataclasses.field(default=300) # 5 minutes default
+ sample_size: int = dataclasses.field(default=100)
+
+ @property
+ def properties_to_log(self) -> dict[str, typing.Any]:
+ return {
+ "started_after": self.started_after,
+ "started_before": self.started_before,
+ "window_seconds": self.window_seconds,
+ "sample_size": self.sample_size,
+ }
+
+
+@temporalio.workflow.defn(name="compare-sampled-recording-events")
+class CompareSampledRecordingEventsWorkflow(PostHogWorkflow):
+ """Workflow to compare recording events between v1 and v2 for sampled sessions."""
+
+ def __init__(self) -> None:
+ self.lock = asyncio.Lock()
+ self.paused = False
+
+ @staticmethod
+ def parse_inputs(inputs: list[str]) -> CompareSampledRecordingEventsWorkflowInputs:
+ """Parse inputs from the management command CLI."""
+ loaded = json.loads(inputs[0])
+
+ for field in ["started_after", "started_before"]:
+ if field not in loaded:
+ raise ValueError(f"Required field {field} not provided")
+ loaded[field] = dt.datetime.fromisoformat(loaded[field])
+
+ window_seconds = loaded.get("window_seconds", 300)
+ if not isinstance(window_seconds, int) or window_seconds <= 0:
+ raise ValueError("window_seconds must be a positive integer")
+
+ sample_size = loaded.get("sample_size", 100)
+ if not isinstance(sample_size, int) or sample_size <= 0:
+ raise ValueError("sample_size must be a positive integer")
+
+ return CompareSampledRecordingEventsWorkflowInputs(
+ started_after=loaded["started_after"],
+ started_before=loaded["started_before"],
+ window_seconds=window_seconds,
+ sample_size=sample_size,
+ )
+
+ @staticmethod
+ def generate_time_windows(
+ start_time: dt.datetime, end_time: dt.datetime, window_seconds: int
+ ) -> list[tuple[dt.datetime, dt.datetime]]:
+ """Generate time windows between start and end time."""
+ windows = []
+ current = start_time
+
+ while current < end_time:
+ window_end = min(current + dt.timedelta(seconds=window_seconds), end_time)
+ windows.append((current, window_end))
+ current = window_end
+
+ return windows
+
+ @temporalio.workflow.run
+ async def run(self, inputs: CompareSampledRecordingEventsWorkflowInputs):
+ """Run the comparison of recording events."""
+ await temporalio.workflow.wait_condition(lambda: not self.paused)
greptile
logic: The workflow could deadlock if it is paused and never unpaused, because this wait_condition call has no timeout.
suggested fix
+ await temporalio.workflow.wait_condition(lambda: not self.paused, timeout=dt.timedelta(hours=24)) # Add timeout to prevent indefinite hang
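The bounded-wait idea in this suggestion can be sketched outside Temporal with plain asyncio. This is a minimal illustration under stated assumptions, not the workflow's actual code: `PauseGate`, `wait_until_resumed`, `demo`, and the timeout values are hypothetical names chosen for the example, and `asyncio.wait_for` stands in for the `timeout` argument of `wait_condition`.

```python
import asyncio


class PauseGate:
    """Hypothetical stand-in for the workflow's `paused` flag."""

    def __init__(self) -> None:
        self._resumed = asyncio.Event()
        self._resumed.set()  # start in the unpaused state

    def pause(self) -> None:
        self._resumed.clear()

    def resume(self) -> None:
        self._resumed.set()

    async def wait_until_resumed(self, timeout: float) -> bool:
        """Return True once resumed, or False if `timeout` seconds pass first."""
        try:
            await asyncio.wait_for(self._resumed.wait(), timeout=timeout)
            return True
        except asyncio.TimeoutError:
            return False


async def demo() -> tuple[bool, bool]:
    gate = PauseGate()
    gate.pause()
    # Nobody resumes the gate, so the bounded wait gives up instead of hanging.
    gave_up = await gate.wait_until_resumed(timeout=0.05)

    # A resume that arrives before the deadline lets the waiter continue.
    asyncio.get_running_loop().call_later(0.01, gate.resume)
    resumed = await gate.wait_until_resumed(timeout=1.0)
    return gave_up, resumed


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

With a timeout in place the caller still has to decide what a timeout means (fail the run, log and continue, or retry), so the sketch returns a boolean rather than hiding that choice.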
diff block
}
}
-func (p *Processor) storeSignedVAA(v *vaa.VAA) error {
+// storeSignedVAA schedules a database update for a VAA.
+func (p *Processor) storeSignedVAA(v *vaa.VAA) {
if v.EmitterChain == vaa.ChainIDPythNet {
key := fmt.Sprintf("%v/%v", v.EmitterAddress, v.Sequence)
p.pythnetVaas[key] = PythNetVaaEntry{v: v, updateTime: time.Now()}
- return nil
+ return
}
- return p.db.StoreSignedVAA(v)
+ key := fmt.Sprintf("%d/%v/%v", v.EmitterChain, v.EmitterAddress, v.Sequence)
+ p.updateVAALock.Lock()
+ p.updatedVAAs[key] = &updateVaaEntry{v: v, dirty: true}
+ p.updateVAALock.Unlock()
}
greptile
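style: Consider a deferred unlock (`defer p.updateVAALock.Unlock()` immediately after `Lock()`) so the mutex is released even if the code in between panics, which prevents a potential deadlock. The lock scope itself is already minimal here.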
diff block
let r = PENDING_DELETE_FILES.read().await;
files.retain(|file| !r.contains(file));
drop(r);
+
+ let r = REMOVING_FILES.read().await;
+ files.retain(|file| !r.contains(file));
+ drop(r);
greptile
logic: Potential deadlock if REMOVING_FILES and PENDING_DELETE_FILES are acquired in different orders elsewhere in the codebase. Consider combining the filters into a single pass.
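The ordering concern is language-neutral, so here is a rough sketch in Python rather than the Rust of the diff. Everything in it is hypothetical (`FileRegistry`, its lock and set attributes, and `filter_files` are illustrative stand-ins, not the project's code): both locks are always taken in one fixed order, and the list is filtered in a single pass.

```python
import threading


class FileRegistry:
    """Hypothetical holder for the two shared sets and their locks."""

    def __init__(self) -> None:
        self._pending_lock = threading.Lock()
        self._removing_lock = threading.Lock()
        self.pending_delete: set[str] = set()
        self.removing: set[str] = set()

    def filter_files(self, files: list[str]) -> list[str]:
        """Drop files that are pending deletion or currently being removed."""
        # Every caller takes _pending_lock before _removing_lock. With one
        # fixed order, two threads can never each hold one lock while
        # waiting for the other.
        with self._pending_lock, self._removing_lock:
            return [
                f
                for f in files
                if f not in self.pending_delete and f not in self.removing
            ]


if __name__ == "__main__":
    registry = FileRegistry()
    registry.pending_delete.add("a.parquet")
    registry.removing.add("c.parquet")
    print(registry.filter_files(["a.parquet", "b.parquet", "c.parquet"]))
```

Holding both locks at once is only safe if every other code path follows the same order. The diff as written sidesteps nesting by dropping each guard before taking the next, at the cost of two passes over the list.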
diff block
throw new OperationOutcomeError(forbidden);
}
+ // START PRE-COMMIT BOT CHECK
+
+ const projectId = this.context.projects?.[0];
+ if (projectId) {
+ const systemRepo = getSystemRepo();
+ const logger = getLogger();
+ const subscriptions = await systemRepo.searchResources<Subscription>({
+ resourceType: 'Subscription',
+ count: 1000,
+ filters: [
+ {
+ code: '_project',
+ operator: Operator.EQUALS,
+ value: projectId,
+ },
+ {
+ code: 'status',
+ operator: Operator.EQUALS,
+ value: 'active',
+ },
+ ],
+ });
+
+ for (const subscription of subscriptions) {
+ // Only consider pre-commit subscriptions
+ if (!getExtension(subscription, 'https://medplum.com/fhir/StructureDefinition/pre-commit-bot')?.valueBoolean) {
+ continue;
+ }
+
+ // Check subscription criteria
+ if (
+ !resourceMatchesSubscriptionCriteria({
+ resource,
+ subscription,
+ logger,
+ context: { interaction: 'update' },
+ getPreviousResource: async () => undefined,
+ })
+ ) {
+ continue;
+ }
+
+ // URL should be a Bot reference string
+ const url = subscription.channel?.endpoint;
+ if (!url?.startsWith('Bot/')) {
+ // Skip if the URL is not a Bot reference
+ continue;
+ }
+
+ const bot = await systemRepo.readReference<Bot>({ reference: url });
+ const runAs = await findProjectMembership(projectId, createReference(bot));
+ if (!runAs) {
+ // Skip if the Bot is not in the project
+ continue;
+ }
+
+ const botResult = await executeBot({
+ subscription,
+ bot,
+ runAs,
+ input: resource,
+ contentType: ContentType.FHIR_JSON,
+ requestTime: new Date().toISOString(),
+ });
greptile
style: Executing a bot inside a transaction could lead to deadlocks or timeouts. Consider executing bots before the transaction starts.
diff block
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use futures_util::{StreamExt, TryStreamExt};
+use sqlx::{
+ migrate::MigrateDatabase,
+ sqlite::{SqliteConnectOptions, SqlitePoolOptions},
+ Sqlite,
+};
+use tokio::sync::Mutex;
+use uuid::Uuid;
+
+use crate::Error;
+
+pub type SqlitePool = sqlx::SqlitePool;
+
+#[derive(Clone)]
+pub struct SqlitePoolManager {
+ // TODO: Somehow remove old pools
+ pools: Arc<Mutex<HashMap<Uuid, SqlitePool>>>,
+}
+
+impl SqlitePoolManager {
+ pub fn new() -> Self {
+ SqlitePoolManager {
+ pools: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ /// Get or creates an sqlite pool for the given key
+ pub async fn get(&self, key: Uuid) -> Result<SqlitePool, Error> {
+ let mut pools_guard = self.pools.lock().await;
+
+ let pool = if let Some(pool) = pools_guard.get(&key) {
+ pool.clone()
+ } else {
+ // TODO: Hardcoded for testing
+ let db_url = format!("sqlite:///home/rivet/rivet-ee/oss/packages/common/chirp-workflow/core/tests/db/{key}.db");
+
+ tracing::debug!(?key, "sqlite connecting");
+
+ // Init if doesn't exist
+ if !Sqlite::database_exists(&db_url)
+ .await
+ .map_err(Error::BuildSqlx)?
+ {
+ Sqlite::create_database(&db_url)
+ .await
+ .map_err(Error::BuildSqlx)?;
+ }
+
+ let opts: SqliteConnectOptions = db_url.parse().map_err(Error::BuildSqlx)?;
+
+ let pool = SqlitePoolOptions::new()
+ .max_lifetime_jitter(Duration::from_secs(90))
+ // Open connection immediately on startup
+ .min_connections(1)
+ .connect_with(opts)
+ .await
+ .map_err(Error::BuildSqlx)?;
+
+ // Run at the start of every connection
+ setup_pragma(&pool).await.map_err(Error::BuildSqlx)?;
+
+ pools_guard.insert(key, pool.clone());
+
+ tracing::debug!(?key, "sqlite connected");
+
+ pool
+ };
+
+ Ok(pool)
+ }
+}
+
+async fn setup_pragma(pool: &SqlitePool) -> Result<(), sqlx::Error> {
+ // Has to be String instead of static str due to a weird compiler bug. This crate will compile just fine
+ // but chirp-workflow will not and the error has nothing to do with this code
+ let settings = [
+ // Set the journal mode to Write-Ahead Logging 2 for concurrency
+ "PRAGMA journal_mode = WAL2".to_string(),
+ // Set synchronous mode to NORMAL for performance and data safety balance
+ "PRAGMA synchronous = NORMAL".to_string(),
+ // Set busy timeout to 5 seconds to avoid "database is locked" errors
+ "PRAGMA busy_timeout = 5000".to_string(),
+ // Enable foreign key constraint enforcement
+ "PRAGMA foreign_keys = ON".to_string(),
+ // Enable auto vacuuming and set it to incremental mode for gradual space reclaiming
+ "PRAGMA auto_vacuum = INCREMENTAL".to_string(),
+ ];
+
+ futures_util::stream::iter(settings)
+ .map(|setting| {
+ let pool = pool.clone();
+ async move {
+ // Attempt to use an existing connection
+ let mut conn = if let Some(conn) = pool.try_acquire() {
+ conn
+ } else {
+ // Create a new connection
+ pool.acquire().await?
+ };
greptile
logic: Connection acquisition should have a timeout to prevent deadlocks
diff block
+use std::{
+ future::Future,
+ ops::{Deref, DerefMut},
+};
+
+use sqlx::sqlite::SqliteQueryResult;
+use sqlx::{Executor, SqliteConnection};
+
+pub trait SqliteConnectionExt {
+ fn begin_immediate(&mut self) -> impl Future<Output = sqlx::Result<Transaction>>;
+}
+
+impl SqliteConnectionExt for SqliteConnection {
+ async fn begin_immediate(&mut self) -> sqlx::Result<Transaction> {
+ let conn = &mut *self;
+
+ conn.execute("BEGIN IMMEDIATE;").await?;
+
+ Ok(Transaction {
+ conn,
+ is_open: true,
+ })
+ }
+}
+
+pub struct Transaction<'c> {
+ conn: &'c mut SqliteConnection,
+ /// is the transaction open?
+ is_open: bool,
+}
+
+impl<'c> Transaction<'c> {
+ pub async fn commit(mut self) -> sqlx::Result<SqliteQueryResult> {
+ let res = self.conn.execute("COMMIT;").await;
+
+ if res.is_ok() {
+ self.is_open = false;
+ }
+
+ res
+ }
+}
+
+impl<'c> Drop for Transaction<'c> {
+ fn drop(&mut self) {
+ if self.is_open {
+ let handle = tokio::runtime::Handle::current();
+ handle.block_on(async move {
+ let _ = self.execute("ROLLBACK").await;
+ });
+ }
greptile
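logic: Blocking in Drop can cause deadlocks or panics, for example if the runtime is shutting down or if the drop runs inside an async context, where Handle::block_on panics. Consider exposing an explicit async rollback method instead, and log the error if the rollback fails rather than discarding it.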
diff block
let m_context = context.clone();
let m_cache = cache.clone();
let handle = tokio::spawn(async move {
- let mut tries = 0;
+ let mut tries: u64 = 0;
// We occasionally enocounter deadlocks while issuing updates, so we retry a few times, and
// if we still fail, we drop the batch and clear it's content from the cached update set, because
greptile
syntax: Typos: 'enocounter' should be 'encounter', and 'it's content' should be 'its content'
suggested fix
+ // We occasionally encounter deadlocks while issuing updates, so we retry a few times, and
+ // if we still fail, we drop the batch and clear its content from the cached update set, because