feat: classify on article creation
This commit is contained in:
@@ -12,6 +12,7 @@ import { cn } from "@basango/ui/lib/utils";
|
|||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
import * as React from "react";
|
import * as React from "react";
|
||||||
|
|
||||||
|
import { Show } from "#dashboard/components/shell/show";
|
||||||
import { useTRPC } from "#dashboard/trpc/client";
|
import { useTRPC } from "#dashboard/trpc/client";
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
@@ -25,7 +26,6 @@ export function CategoriesCarousel({ onSelect, selectedCategory }: Props) {
|
|||||||
const trpc = useTRPC();
|
const trpc = useTRPC();
|
||||||
const { data, isLoading } = useQuery(trpc.categories.list.queryOptions());
|
const { data, isLoading } = useQuery(trpc.categories.list.queryOptions());
|
||||||
const categories = data ?? [];
|
const categories = data ?? [];
|
||||||
const showSkeletons = isLoading && categories.length === 0;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
@@ -43,22 +43,25 @@ export function CategoriesCarousel({ onSelect, selectedCategory }: Props) {
|
|||||||
All
|
All
|
||||||
</CategoryPill>
|
</CategoryPill>
|
||||||
</CarouselItem>
|
</CarouselItem>
|
||||||
{showSkeletons
|
<Show
|
||||||
? Array.from({ length: PLACEHOLDER_COUNT }).map((_, index) => (
|
fallback={Array.from({ length: PLACEHOLDER_COUNT }).map((_, index) => (
|
||||||
<CarouselItem className="basis-auto pl-2" key={`category-skeleton-${index}`}>
|
<CarouselItem className="basis-auto pl-2" key={`category-skeleton-${index}`}>
|
||||||
<Skeleton className="h-8 w-20 rounded-full bg-muted/70" />
|
<Skeleton className="h-8 w-20 rounded-full bg-muted/70" />
|
||||||
</CarouselItem>
|
</CarouselItem>
|
||||||
))
|
))}
|
||||||
: categories.map((category) => (
|
when={isLoading && categories.length > 0}
|
||||||
<CarouselItem className="basis-auto pl-2" key={category.id}>
|
>
|
||||||
<CategoryPill
|
{categories.map((category) => (
|
||||||
active={selectedCategory === category.id}
|
<CarouselItem className="basis-auto pl-2" key={category.id}>
|
||||||
onClick={() => onSelect(category.id)}
|
<CategoryPill
|
||||||
>
|
active={selectedCategory === category.id}
|
||||||
{category.name}
|
onClick={() => onSelect(category.id)}
|
||||||
</CategoryPill>
|
>
|
||||||
</CarouselItem>
|
{category.name}
|
||||||
))}
|
</CategoryPill>
|
||||||
|
</CarouselItem>
|
||||||
|
))}
|
||||||
|
</Show>
|
||||||
</CarouselContent>
|
</CarouselContent>
|
||||||
<CarouselPrevious className="hidden md:flex" size="icon" />
|
<CarouselPrevious className="hidden md:flex" size="icon" />
|
||||||
<CarouselNext className="hidden md:flex" size="icon" />
|
<CarouselNext className="hidden md:flex" size="icon" />
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import * as uuid from "uuid";
|
|||||||
import { Database } from "#db/client";
|
import { Database } from "#db/client";
|
||||||
import { getSourceIdByName } from "#db/queries/sources";
|
import { getSourceIdByName } from "#db/queries/sources";
|
||||||
import { articles, categories, sources } from "#db/schema";
|
import { articles, categories, sources } from "#db/schema";
|
||||||
|
import { classifyCategory } from "#db/services/category-classifier";
|
||||||
import { CreateArticleParams, GetArticlesParams } from "#db/types/articles";
|
import { CreateArticleParams, GetArticlesParams } from "#db/types/articles";
|
||||||
import { GetDistributionsParams, GetPublicationsParams } from "#db/types/shared";
|
import { GetDistributionsParams, GetPublicationsParams } from "#db/types/shared";
|
||||||
import {
|
import {
|
||||||
@@ -41,24 +42,26 @@ export async function createArticle(db: Database, params: CreateArticleParams) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const categoryList = params.categories ?? [];
|
|
||||||
const data = {
|
const data = {
|
||||||
...params,
|
...params,
|
||||||
categories: categoryList,
|
categories: params.categories ?? [],
|
||||||
hash: md5(params.link),
|
hash: md5(params.link),
|
||||||
|
id: uuid.v7(),
|
||||||
readingTime: computeReadingTime(params.body),
|
readingTime: computeReadingTime(params.body),
|
||||||
sentiment: (params.sentiment ?? "neutral") as Sentiment,
|
sentiment: (params.sentiment ?? "neutral") as Sentiment,
|
||||||
sourceId: await getSourceIdByName(db, params.sourceId),
|
sourceId: await getSourceIdByName(db, params.sourceId),
|
||||||
tokenStatistics: computeTokenStatistics({
|
tokenStatistics: computeTokenStatistics({
|
||||||
body: params.body,
|
body: params.body,
|
||||||
categories: categoryList,
|
categories: params.categories ?? [],
|
||||||
title: params.title,
|
title: params.title,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
data.categoryId = classifyCategory(data).category.id;
|
||||||
|
|
||||||
const [result] = await db
|
const [result] = await db
|
||||||
.insert(articles)
|
.insert(articles)
|
||||||
.values({ id: uuid.v7(), ...data })
|
.values({ ...data })
|
||||||
.returning({
|
.returning({
|
||||||
id: articles.id,
|
id: articles.id,
|
||||||
sourceId: articles.sourceId,
|
sourceId: articles.sourceId,
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ export class CategoryClassifier {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function classifyCategory(article: ArticleCategories): CategoryScore {
|
export function classifyCategory(article: ArticleCategories): CategoryScore {
|
||||||
const rawCategories = article.categories ?? [];
|
const rawCategories = article.categories ?? [];
|
||||||
const normalizedCategories = Array.from(
|
const normalizedCategories = Array.from(
|
||||||
new Set(
|
new Set(
|
||||||
|
|||||||
Reference in New Issue
Block a user