feat: authenticate WebSocket upgrades in authenticated mode

Resolve Better Auth sessions from raw headers for WS upgrade
requests. Verify instance admin or company membership before
allowing live-events connections in authenticated mode.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Forgotten
2026-02-25 08:39:20 -06:00
parent 1c2873d22a
commit 32cbdbc0b9
3 changed files with 90 additions and 13 deletions

View File

@@ -1,4 +1,5 @@
import type { Request, RequestHandler } from "express"; import type { Request, RequestHandler } from "express";
import type { IncomingHttpHeaders } from "node:http";
import { betterAuth } from "better-auth"; import { betterAuth } from "better-auth";
import { drizzleAdapter } from "better-auth/adapters/drizzle"; import { drizzleAdapter } from "better-auth/adapters/drizzle";
import { toNodeHandler } from "better-auth/node"; import { toNodeHandler } from "better-auth/node";
@@ -24,9 +25,9 @@ export type BetterAuthSessionResult = {
type BetterAuthInstance = ReturnType<typeof betterAuth>; type BetterAuthInstance = ReturnType<typeof betterAuth>;
function headersFromExpressRequest(req: Request): Headers { function headersFromNodeHeaders(rawHeaders: IncomingHttpHeaders): Headers {
const headers = new Headers(); const headers = new Headers();
for (const [key, raw] of Object.entries(req.headers)) { for (const [key, raw] of Object.entries(rawHeaders)) {
if (!raw) continue; if (!raw) continue;
if (Array.isArray(raw)) { if (Array.isArray(raw)) {
for (const value of raw) headers.append(key, value); for (const value of raw) headers.append(key, value);
@@ -37,6 +38,10 @@ function headersFromExpressRequest(req: Request): Headers {
return headers; return headers;
} }
function headersFromExpressRequest(req: Request): Headers {
return headersFromNodeHeaders(req.headers);
}
export function createBetterAuthInstance(db: Db, config: Config): BetterAuthInstance { export function createBetterAuthInstance(db: Db, config: Config): BetterAuthInstance {
const baseUrl = config.authBaseUrlMode === "explicit" ? config.authPublicBaseUrl : undefined; const baseUrl = config.authBaseUrlMode === "explicit" ? config.authPublicBaseUrl : undefined;
const secret = process.env.BETTER_AUTH_SECRET ?? process.env.PAPERCLIP_AGENT_JWT_SECRET ?? "paperclip-dev-secret"; const secret = process.env.BETTER_AUTH_SECRET ?? process.env.PAPERCLIP_AGENT_JWT_SECRET ?? "paperclip-dev-secret";
@@ -73,15 +78,15 @@ export function createBetterAuthHandler(auth: BetterAuthInstance): RequestHandle
}; };
} }
export async function resolveBetterAuthSession( export async function resolveBetterAuthSessionFromHeaders(
auth: BetterAuthInstance, auth: BetterAuthInstance,
req: Request, headers: Headers,
): Promise<BetterAuthSessionResult | null> { ): Promise<BetterAuthSessionResult | null> {
const api = (auth as unknown as { api?: { getSession?: (input: unknown) => Promise<unknown> } }).api; const api = (auth as unknown as { api?: { getSession?: (input: unknown) => Promise<unknown> } }).api;
if (!api?.getSession) return null; if (!api?.getSession) return null;
const sessionValue = await api.getSession({ const sessionValue = await api.getSession({
headers: headersFromExpressRequest(req), headers,
}); });
if (!sessionValue || typeof sessionValue !== "object") return null; if (!sessionValue || typeof sessionValue !== "object") return null;
@@ -103,3 +108,10 @@ export async function resolveBetterAuthSession(
if (!session || !user) return null; if (!session || !user) return null;
return { session, user }; return { session, user };
} }
export async function resolveBetterAuthSession(
auth: BetterAuthInstance,
req: Request,
): Promise<BetterAuthSessionResult | null> {
return resolveBetterAuthSessionFromHeaders(auth, headersFromExpressRequest(req));
}

View File

@@ -29,6 +29,7 @@ import {
createBetterAuthHandler, createBetterAuthHandler,
createBetterAuthInstance, createBetterAuthInstance,
resolveBetterAuthSession, resolveBetterAuthSession,
resolveBetterAuthSessionFromHeaders,
} from "./auth/better-auth.js"; } from "./auth/better-auth.js";
type EmbeddedPostgresInstance = { type EmbeddedPostgresInstance = {
@@ -324,6 +325,9 @@ let betterAuthHandler: ReturnType<typeof createBetterAuthHandler> | undefined;
let resolveSession: let resolveSession:
| ((req: ExpressRequest) => Promise<Awaited<ReturnType<typeof resolveBetterAuthSession>>>) | ((req: ExpressRequest) => Promise<Awaited<ReturnType<typeof resolveBetterAuthSession>>>)
| undefined; | undefined;
let resolveSessionFromHeaders:
| ((headers: Headers) => Promise<Awaited<ReturnType<typeof resolveBetterAuthSession>>>)
| undefined;
if (config.deploymentMode === "local_trusted") { if (config.deploymentMode === "local_trusted") {
await ensureLocalTrustedBoardPrincipal(db as any); await ensureLocalTrustedBoardPrincipal(db as any);
} }
@@ -338,6 +342,7 @@ if (config.deploymentMode === "authenticated") {
const auth = createBetterAuthInstance(db as any, config); const auth = createBetterAuthInstance(db as any, config);
betterAuthHandler = createBetterAuthHandler(auth); betterAuthHandler = createBetterAuthHandler(auth);
resolveSession = (req) => resolveBetterAuthSession(auth, req); resolveSession = (req) => resolveBetterAuthSession(auth, req);
resolveSessionFromHeaders = (headers) => resolveBetterAuthSessionFromHeaders(auth, headers);
await initializeBoardClaimChallenge(db as any, { deploymentMode: config.deploymentMode }); await initializeBoardClaimChallenge(db as any, { deploymentMode: config.deploymentMode });
authReady = true; authReady = true;
} }
@@ -362,7 +367,10 @@ if (listenPort !== config.port) {
logger.warn({ requestedPort: config.port, selectedPort: listenPort }, "Requested port is busy; using next free port"); logger.warn({ requestedPort: config.port, selectedPort: listenPort }, "Requested port is busy; using next free port");
} }
setupLiveEventsWebSocketServer(server, db as any, { deploymentMode: config.deploymentMode }); setupLiveEventsWebSocketServer(server, db as any, {
deploymentMode: config.deploymentMode,
resolveSessionFromHeaders,
});
if (config.heartbeatSchedulerEnabled) { if (config.heartbeatSchedulerEnabled) {
const heartbeat = heartbeatService(db as any); const heartbeat = heartbeatService(db as any);

View File

@@ -3,9 +3,10 @@ import type { IncomingMessage, Server as HttpServer } from "node:http";
import type { Duplex } from "node:stream"; import type { Duplex } from "node:stream";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import type { Db } from "@paperclip/db"; import type { Db } from "@paperclip/db";
import { agentApiKeys } from "@paperclip/db"; import { agentApiKeys, companyMemberships, instanceUserRoles } from "@paperclip/db";
import type { DeploymentMode } from "@paperclip/shared"; import type { DeploymentMode } from "@paperclip/shared";
import { WebSocket, WebSocketServer } from "ws"; import { WebSocket, WebSocketServer } from "ws";
import type { BetterAuthSessionResult } from "../auth/better-auth.js";
import { logger } from "../middleware/logger.js"; import { logger } from "../middleware/logger.js";
import { subscribeCompanyLiveEvents } from "../services/live-events.js"; import { subscribeCompanyLiveEvents } from "../services/live-events.js";
@@ -48,26 +49,76 @@ function parseBearerToken(rawAuth: string | string[] | undefined) {
return token.length > 0 ? token : null; return token.length > 0 ? token : null;
} }
function headersFromIncomingMessage(req: IncomingMessage): Headers {
const headers = new Headers();
for (const [key, raw] of Object.entries(req.headers)) {
if (!raw) continue;
if (Array.isArray(raw)) {
for (const value of raw) headers.append(key, value);
continue;
}
headers.set(key, raw);
}
return headers;
}
async function authorizeUpgrade( async function authorizeUpgrade(
db: Db, db: Db,
req: IncomingMessage, req: IncomingMessage,
companyId: string, companyId: string,
url: URL, url: URL,
deploymentMode: DeploymentMode, opts: {
deploymentMode: DeploymentMode;
resolveSessionFromHeaders?: (headers: Headers) => Promise<BetterAuthSessionResult | null>;
},
): Promise<UpgradeContext | null> { ): Promise<UpgradeContext | null> {
const queryToken = url.searchParams.get("token")?.trim() ?? ""; const queryToken = url.searchParams.get("token")?.trim() ?? "";
const authToken = parseBearerToken(req.headers.authorization); const authToken = parseBearerToken(req.headers.authorization);
const token = authToken ?? (queryToken.length > 0 ? queryToken : null); const token = authToken ?? (queryToken.length > 0 ? queryToken : null);
// Local trusted browser board context has no bearer token in V1. // Browser board context has no bearer token in local_trusted and authenticated modes.
if (!token) { if (!token) {
if (deploymentMode !== "local_trusted") { if (opts.deploymentMode === "local_trusted") {
return {
companyId,
actorType: "board",
actorId: "board",
};
}
if (opts.deploymentMode !== "authenticated" || !opts.resolveSessionFromHeaders) {
return null; return null;
} }
const session = await opts.resolveSessionFromHeaders(headersFromIncomingMessage(req));
const userId = session?.user?.id;
if (!userId) return null;
const [roleRow, memberships] = await Promise.all([
db
.select({ id: instanceUserRoles.id })
.from(instanceUserRoles)
.where(and(eq(instanceUserRoles.userId, userId), eq(instanceUserRoles.role, "instance_admin")))
.then((rows) => rows[0] ?? null),
db
.select({ companyId: companyMemberships.companyId })
.from(companyMemberships)
.where(
and(
eq(companyMemberships.principalType, "user"),
eq(companyMemberships.principalId, userId),
eq(companyMemberships.status, "active"),
),
),
]);
const hasCompanyMembership = memberships.some((row) => row.companyId === companyId);
if (!roleRow && !hasCompanyMembership) return null;
return { return {
companyId, companyId,
actorType: "board", actorType: "board",
actorId: "board", actorId: userId,
}; };
} }
@@ -97,7 +148,10 @@ async function authorizeUpgrade(
export function setupLiveEventsWebSocketServer( export function setupLiveEventsWebSocketServer(
server: HttpServer, server: HttpServer,
db: Db, db: Db,
opts: { deploymentMode: DeploymentMode }, opts: {
deploymentMode: DeploymentMode;
resolveSessionFromHeaders?: (headers: Headers) => Promise<BetterAuthSessionResult | null>;
},
) { ) {
const wss = new WebSocketServer({ noServer: true }); const wss = new WebSocketServer({ noServer: true });
const cleanupByClient = new Map<WebSocket, () => void>(); const cleanupByClient = new Map<WebSocket, () => void>();
@@ -162,7 +216,10 @@ export function setupLiveEventsWebSocketServer(
return; return;
} }
void authorizeUpgrade(db, req, companyId, url, opts.deploymentMode) void authorizeUpgrade(db, req, companyId, url, {
deploymentMode: opts.deploymentMode,
resolveSessionFromHeaders: opts.resolveSessionFromHeaders,
})
.then((context) => { .then((context) => {
if (!context) { if (!context) {
rejectUpgrade(socket, "403 Forbidden", "forbidden"); rejectUpgrade(socket, "403 Forbidden", "forbidden");