feat(api): authentication
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import { Database, db } from "@basango/db/client";
|
||||
import { initTRPC } from "@trpc/server";
|
||||
import { TRPCError, initTRPC } from "@trpc/server";
|
||||
import type { Context } from "hono";
|
||||
import superjson from "superjson";
|
||||
|
||||
import { withAuthentication } from "#api/trpc/middlewares/auth";
|
||||
import { withDatabase } from "#api/trpc/middlewares/db";
|
||||
import { Session, verifyAccessToken } from "#api/utils/auth";
|
||||
import { Session, getSession } from "#api/utils/auth";
|
||||
import { getGeoContext } from "#api/utils/geo";
|
||||
|
||||
type TRPCContext = {
|
||||
@@ -16,7 +16,7 @@ type TRPCContext = {
|
||||
|
||||
export const createTRPCContext = async (_: unknown, c: Context): Promise<TRPCContext> => {
|
||||
const accessToken = c.req.header("Authorization")?.split(" ")[1];
|
||||
const session = await verifyAccessToken(accessToken);
|
||||
const session = await getSession(db, accessToken);
|
||||
const geo = getGeoContext(c.req);
|
||||
|
||||
return {
|
||||
@@ -51,13 +51,13 @@ export const publicProcedure = t.procedure.use(withDatabaseMiddleware);
|
||||
|
||||
export const protectedProcedure = t.procedure
|
||||
.use(withDatabaseMiddleware)
|
||||
.use(withAutenticationMiddleware) // NOTE: This is needed to ensure that the teamId is set in the context
|
||||
.use(withAutenticationMiddleware)
|
||||
.use(async (opts) => {
|
||||
const { session } = opts.ctx;
|
||||
|
||||
// if (!session) {
|
||||
// throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
// }
|
||||
if (!session) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Database } from "@basango/db/client";
|
||||
|
||||
// import { TRPCError } from "@trpc/server";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
|
||||
import type { Session } from "#api/utils/auth";
|
||||
|
||||
@@ -18,14 +17,12 @@ export const withAuthentication = async <TReturn>(opts: {
|
||||
}) => {
|
||||
const { ctx, next } = opts;
|
||||
|
||||
// const userId = ctx.session?.user?.id;
|
||||
|
||||
// if (!userId) {
|
||||
// throw new TRPCError({
|
||||
// code: "UNAUTHORIZED",
|
||||
// message: "No permission to access",
|
||||
// });
|
||||
// }
|
||||
if (!ctx.session) {
|
||||
throw new TRPCError({
|
||||
code: "UNAUTHORIZED",
|
||||
message: "Authentication is required to access this resource.",
|
||||
});
|
||||
}
|
||||
|
||||
return next({
|
||||
ctx: {
|
||||
|
||||
@@ -2,10 +2,12 @@ import type { inferRouterInputs, inferRouterOutputs } from "@trpc/server";
|
||||
|
||||
import { createTRPCRouter } from "#api/trpc/init";
|
||||
import { articlesRouter } from "#api/trpc/routers/articles";
|
||||
import { authRouter } from "#api/trpc/routers/auth";
|
||||
import { sourcesRouter } from "#api/trpc/routers/sources";
|
||||
|
||||
export const appRouter = createTRPCRouter({
|
||||
articles: articlesRouter,
|
||||
auth: authRouter,
|
||||
sources: sourcesRouter,
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
import { getUserByEmail, getUserById } from "@basango/db/queries";
|
||||
import { loginSchema, refreshSessionSchema } from "@basango/domain/models";
|
||||
import { verifyPassword } from "@basango/encryption";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "#api/trpc/init";
|
||||
import { createSessionTokens, verifyRefreshToken } from "#api/utils/auth";
|
||||
|
||||
export const authRouter = createTRPCRouter({
|
||||
login: publicProcedure.input(loginSchema).mutation(async ({ ctx, input }) => {
|
||||
const user = await getUserByEmail(ctx.db, input.email);
|
||||
|
||||
if (!user || user.isLocked) {
|
||||
throw new TRPCError({
|
||||
code: "UNAUTHORIZED",
|
||||
message: "Invalid credentials.",
|
||||
});
|
||||
}
|
||||
|
||||
const isValidPassword = await verifyPassword(input.password, user.password);
|
||||
|
||||
if (!isValidPassword) {
|
||||
throw new TRPCError({
|
||||
code: "UNAUTHORIZED",
|
||||
message: "Invalid credentials.",
|
||||
});
|
||||
}
|
||||
|
||||
const session = {
|
||||
user: {
|
||||
email: user.email,
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
},
|
||||
};
|
||||
|
||||
const tokens = await createSessionTokens(session);
|
||||
|
||||
return {
|
||||
...tokens,
|
||||
user: session.user,
|
||||
};
|
||||
}),
|
||||
|
||||
refresh: publicProcedure.input(refreshSessionSchema).mutation(async ({ ctx, input }) => {
|
||||
const session = await verifyRefreshToken(input.refreshToken);
|
||||
|
||||
if (!session) {
|
||||
throw new TRPCError({
|
||||
code: "UNAUTHORIZED",
|
||||
message: "Invalid refresh token.",
|
||||
});
|
||||
}
|
||||
|
||||
const user = await getUserById(ctx.db, {
|
||||
email: session.user.email,
|
||||
id: session.user.id,
|
||||
});
|
||||
|
||||
if (!user || user.isLocked) {
|
||||
throw new TRPCError({
|
||||
code: "UNAUTHORIZED",
|
||||
message: "Invalid refresh token.",
|
||||
});
|
||||
}
|
||||
|
||||
const tokens = await createSessionTokens({
|
||||
user: {
|
||||
email: user.email,
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
...tokens,
|
||||
user: {
|
||||
email: user.email,
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
session: protectedProcedure.query(({ ctx }) => ctx.session.user),
|
||||
});
|
||||
+126
-9
@@ -1,4 +1,12 @@
|
||||
import { type JWTPayload, jwtVerify } from "jose";
|
||||
import { Database } from "@basango/db/client";
|
||||
import { getUserById } from "@basango/db/queries";
|
||||
import {
|
||||
DEFAULT_ACCESS_TOKEN_TTL,
|
||||
DEFAULT_REFRESH_TOKEN_TTL,
|
||||
DEFAULT_TOKEN_AUDIENCE,
|
||||
DEFAULT_TOKEN_ISSUER,
|
||||
} from "@basango/domain/constants";
|
||||
import { type JWTPayload, SignJWT, jwtVerify } from "jose";
|
||||
|
||||
import { env } from "#api/config";
|
||||
|
||||
@@ -6,35 +14,144 @@ export type Session = {
|
||||
user: {
|
||||
id: string;
|
||||
email: string;
|
||||
full_name?: string;
|
||||
name?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type VerifiedJWTPayload = JWTPayload & {
|
||||
tokenType: TokenType;
|
||||
user: {
|
||||
id: string;
|
||||
email: string;
|
||||
full_name?: string;
|
||||
name?: string;
|
||||
};
|
||||
};
|
||||
|
||||
type TokenType = "access" | "refresh";
|
||||
|
||||
export type SessionTokens = {
|
||||
accessToken: string;
|
||||
refreshToken: string;
|
||||
accessTokenExpiresAt: string;
|
||||
refreshTokenExpiresAt: string;
|
||||
};
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
function getSecretKey() {
|
||||
return encoder.encode(env("BASANGO_JWT_SECRET"));
|
||||
}
|
||||
|
||||
export async function getSession(db: Database, accessToken?: string): Promise<Session | null> {
|
||||
const session = await verifyAccessToken(accessToken);
|
||||
|
||||
if (!session) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const user = await getUserById(db, {
|
||||
email: session.user.email,
|
||||
id: session.user.id,
|
||||
});
|
||||
|
||||
if (!user || user.isLocked) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
user: {
|
||||
email: user.email,
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function createToken(session: Session, tokenType: TokenType, expiresIn: string) {
|
||||
return new SignJWT({
|
||||
tokenType,
|
||||
user: session.user,
|
||||
})
|
||||
.setProtectedHeader({ alg: "HS256" })
|
||||
.setIssuedAt()
|
||||
.setAudience(DEFAULT_TOKEN_AUDIENCE)
|
||||
.setIssuer(DEFAULT_TOKEN_ISSUER)
|
||||
.setExpirationTime(expiresIn)
|
||||
.sign(getSecretKey());
|
||||
}
|
||||
|
||||
export async function createSessionTokens(session: Session): Promise<SessionTokens> {
|
||||
const [accessToken, refreshToken] = await Promise.all([
|
||||
createToken(session, "access", DEFAULT_ACCESS_TOKEN_TTL),
|
||||
createToken(session, "refresh", DEFAULT_REFRESH_TOKEN_TTL),
|
||||
]);
|
||||
|
||||
const issuedAt = Date.now();
|
||||
const accessTokenExpiresAt = new Date(
|
||||
issuedAt + formatTTL(DEFAULT_ACCESS_TOKEN_TTL),
|
||||
).toISOString();
|
||||
const refreshTokenExpiresAt = new Date(
|
||||
issuedAt + formatTTL(DEFAULT_REFRESH_TOKEN_TTL),
|
||||
).toISOString();
|
||||
|
||||
return {
|
||||
accessToken,
|
||||
accessTokenExpiresAt,
|
||||
refreshToken,
|
||||
refreshTokenExpiresAt,
|
||||
};
|
||||
}
|
||||
|
||||
export async function verifyAccessToken(accessToken?: string): Promise<Session | null> {
|
||||
if (!accessToken) return null;
|
||||
return verifyToken(accessToken, "access");
|
||||
}
|
||||
|
||||
export async function verifyRefreshToken(refreshToken?: string): Promise<Session | null> {
|
||||
return verifyToken(refreshToken, "refresh");
|
||||
}
|
||||
|
||||
async function verifyToken(
|
||||
token: string | undefined,
|
||||
expectedType: TokenType,
|
||||
): Promise<Session | null> {
|
||||
if (!token) return null;
|
||||
|
||||
try {
|
||||
const { payload } = await jwtVerify<VerifiedJWTPayload>(
|
||||
accessToken,
|
||||
new TextEncoder().encode(env("BASANGO_JWT_SECRET")),
|
||||
);
|
||||
const { payload } = await jwtVerify<VerifiedJWTPayload>(token, getSecretKey(), {
|
||||
audience: DEFAULT_TOKEN_AUDIENCE,
|
||||
issuer: DEFAULT_TOKEN_ISSUER,
|
||||
});
|
||||
|
||||
if (payload.tokenType !== expectedType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
user: {
|
||||
email: payload.user.email,
|
||||
full_name: payload.user.full_name,
|
||||
id: payload.user.id,
|
||||
name: payload.user.name,
|
||||
},
|
||||
};
|
||||
} catch (_error: unknown) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function formatTTL(ttl: string) {
|
||||
const match = ttl.match(/^(\d+)([smhd])$/);
|
||||
if (!match) return 0;
|
||||
const [, rawValue, rawUnit] = match;
|
||||
if (!rawValue || !rawUnit) {
|
||||
return 0;
|
||||
}
|
||||
const value = Number.parseInt(rawValue, 10);
|
||||
const multipliers = {
|
||||
d: 86_400_000,
|
||||
h: 3_600_000,
|
||||
m: 60_000,
|
||||
s: 1_000,
|
||||
} as const;
|
||||
const unit = rawUnit as keyof typeof multipliers;
|
||||
return value * (multipliers[unit] ?? 1_000);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user