xucaiqin 4 ay önce
ebeveyn
işleme
c18a46d0fd

+ 16 - 3
sckw-ai-biz/src/main/java/com/sckw/ai/biz/controller/ChatController.java

@@ -3,17 +3,18 @@ package com.sckw.ai.biz.controller;
 import com.sckw.ai.biz.core.CommonResult;
 import com.sckw.ai.biz.pojo.ConversionVo;
 import com.sckw.ai.biz.pojo.dto.ApiPage;
-import com.sckw.ai.biz.pojo.para.ChatPara;
-import com.sckw.ai.biz.pojo.para.ConversionPara;
-import com.sckw.ai.biz.pojo.para.StopChatPara;
+import com.sckw.ai.biz.pojo.para.*;
 import com.sckw.ai.biz.service.ChatService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.annotation.Resource;
 import jakarta.validation.Valid;
 import org.springdoc.core.annotations.ParameterObject;
+import org.springframework.core.io.buffer.DataBuffer;
 import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
 import org.springframework.web.bind.annotation.*;
+import org.springframework.web.multipart.MultipartFile;
 import reactor.core.publisher.Flux;
 
 
@@ -31,6 +32,18 @@ public class ChatController {
         return chatService.chat(chatPara);
     }
 
+    /**
+     * Text-to-speech endpoint (POST {@code /text2Audio}).
+     *
+     * <p>Validates the JSON request body and hands it to {@code chatService.text2Audio};
+     * the service's response entity is returned unchanged.
+     *
+     * @param chatPara validated request payload
+     * @return the response entity produced by {@code chatService.text2Audio}
+     */
+    @Operation(summary = "文本转语音", description = "文本转语音")
+    @PostMapping(value = "/text2Audio")
+    public ResponseEntity<Flux<DataBuffer>> text2Audio(@RequestBody @Valid TextPara chatPara) {
+        return chatService.text2Audio(chatPara);
+    }
+
+    /**
+     * Speech-to-text endpoint (POST {@code /audio2Text}).
+     *
+     * <p>Binds the uploaded file from the request parameter named {@code file} and hands it
+     * to {@code chatService.audio2Text}; the service's result is returned unchanged.
+     *
+     * @param file uploaded file taken from the {@code file} request parameter
+     * @return the result produced by {@code chatService.audio2Text}
+     */
+    @Operation(summary = "音频转文本", description = "音频转文本")
+    @PostMapping(value = "/audio2Text")
+    public CommonResult<TextVo> audio2Text(@RequestParam("file") MultipartFile file) {
+        return chatService.audio2Text(file);
+    }
+
     @Operation(summary = "停止响应", description = "停止响应聊天会话")
     @PostMapping(value = "/stop")
     public CommonResult<Object> stop(@RequestBody @Valid StopChatPara stopChatPara) {

+ 23 - 0
sckw-ai-biz/src/main/java/com/sckw/ai/biz/pojo/para/AuditPara.java

@@ -0,0 +1,23 @@
+package com.sckw.ai.biz.pojo.para;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
+import lombok.Getter;
+import lombok.Setter;
+import org.springframework.web.multipart.MultipartFile;
+
+import java.io.Serial;
+import java.io.Serializable;
+
+/**
+ * Parameter object wrapping an uploaded file for speech-to-text requests.
+ *
+ * <p>NOTE(review): no code visible in this change references this class;
+ * {@code ChatController.audio2Text} binds the {@code MultipartFile} directly via
+ * {@code @RequestParam}. Confirm whether this type is still needed.
+ * NOTE(review): the name looks like it was meant to be "Audio" rather than "Audit" - confirm.
+ */
+@Getter
+@Setter
+@Schema(description = "语音转文本")
+public class AuditPara implements Serializable {
+    @Serial
+    private static final long serialVersionUID = -621415158085036877L;
+
+    /** Uploaded file; validated as non-null. */
+    @Schema(description = "文件", requiredMode = Schema.RequiredMode.REQUIRED)
+    @NotNull(message = "文件不能为空")
+    private MultipartFile file;
+
+}

+ 30 - 0
sckw-ai-biz/src/main/java/com/sckw/ai/biz/pojo/para/TextPara.java

@@ -0,0 +1,30 @@
+package com.sckw.ai.biz.pojo.para;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotBlank;
+import lombok.Getter;
+import lombok.Setter;
+
+import java.io.Serial;
+import java.io.Serializable;
+
+/**
+ * Request payload for the text-to-speech endpoint.
+ *
+ * <p>{@code message_id} and {@code text} are both validated as non-blank. {@code user} is
+ * hidden from the generated API documentation; {@code ChatService.text2Audio} overwrites it
+ * with the id of the logged-in user before calling the upstream API.
+ */
+@Getter
+@Setter
+@Schema(description = "文本转语音")
+public class TextPara implements Serializable {
+    @Serial
+    private static final long serialVersionUID = -621415158085036877L;
+
+    /**
+     * Id of the message to synthesise, bound from the JSON property {@code message_id}.
+     * The schema metadata mirrors the {@code @NotBlank} constraint: the field is required
+     * and is a message id, not a conversation id.
+     */
+    @Schema(description = "消息id", requiredMode = Schema.RequiredMode.REQUIRED)
+    @JsonProperty("message_id")
+    @NotBlank(message = "消息id不能为空")
+    private String messageId;
+
+    /** Text content of the request; must not be blank. */
+    @NotBlank(message = "提问不能为空")
+    @Schema(description = "用户输入/提问内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好")
+    private String text;
+
+    /** User id; hidden from the API documentation. */
+    @Schema(description = "用户id", hidden = true)
+    private String user;
+}

+ 21 - 0
sckw-ai-biz/src/main/java/com/sckw/ai/biz/pojo/para/TextVo.java

@@ -0,0 +1,21 @@
+package com.sckw.ai.biz.pojo.para;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import lombok.Getter;
+import lombok.Setter;
+
+import java.io.Serial;
+import java.io.Serializable;
+
+/**
+ * Response payload carrying recognised text.
+ *
+ * <p>Returned by {@code ChatController.audio2Text}, i.e. it is the result of speech-to-text,
+ * which is what the schema description states.
+ */
+@Getter
+@Setter
+@Schema(description = "语音转文本结果")
+public class TextVo implements Serializable {
+    @Serial
+    private static final long serialVersionUID = -621415158085036877L;
+
+    /** Text field of the response. */
+    @Schema(description = "文本")
+    private String text;
+}

+ 11 - 8
sckw-ai-biz/src/main/java/com/sckw/ai/biz/service/AiApiEnum.java

@@ -10,18 +10,21 @@ import lombok.Getter;
 @Getter
 @AllArgsConstructor
 public enum AiApiEnum {
-    NEW_CHAT("/chat-messages", "POST", "创建会话消息"),
-    LIKE_CHAT("/messages/%s/feedbacks", "POST", "消息点赞"),//参数 message_id
-    CHAT_STOP("/chat-messages/%s/stop", "GET", "停止聊天"),//参数 task_id
-    CONVERSION_LIST("/conversations", "GET", "获取当前用户的会话列表,默认返回最近的 20 条"),
-    MESSAGES("/messages", "GET", "滚动加载形式返回历史聊天记录"),
-    DELETE_CONVERSION("/conversations/%s", "DELETE", "删除会话"),//会话id
-    FILE_PREVIEW("/files/%s/preview", "GET", "文件预览"),//文件id
-    FEEDBACKS("/app/feedbacks", "GET", "获取应用的终端用户反馈、点赞。"),
+    NEW_CHAT("/chat-messages", "POST", "创建会话消息",""),
+    LIKE_CHAT("/messages/%s/feedbacks", "POST", "消息点赞",""),//参数 message_id
+    CHAT_STOP("/chat-messages/%s/stop", "GET", "停止聊天",""),//参数 task_id
+    CONVERSION_LIST("/conversations", "GET", "获取当前用户的会话列表,默认返回最近的 20 条",""),
+    MESSAGES("/messages", "GET", "滚动加载形式返回历史聊天记录",""),
+    DELETE_CONVERSION("/conversations/%s", "DELETE", "删除会话",""),//会话id
+    FILE_PREVIEW("/files/%s/preview", "GET", "文件预览",""),//文件id
+    FEEDBACKS("/app/feedbacks", "GET", "获取应用的终端用户反馈、点赞。",""),
+    AUDIT("/audio-to-text", "POST", "语音转文字","form"),
+    TEXT("/text-to-audio", "POST", "文字转语音",""),
     ;
     private final String url;
     private final String method;
     private final String name;
+    private final String type;
 
     public String formatUrl(Object... args) {
         if (args.length == 0) {

+ 36 - 9
sckw-ai-biz/src/main/java/com/sckw/ai/biz/service/AiApiInvoker.java

@@ -8,7 +8,11 @@ import com.sckw.ai.biz.util.OkHttpUtils;
 import jakarta.annotation.PostConstruct;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
+import okhttp3.MediaType;
+import okhttp3.MultipartBody;
+import okhttp3.RequestBody;
 import org.springframework.stereotype.Component;
+import org.springframework.web.multipart.MultipartFile;
 
 import java.util.Map;
 
@@ -28,14 +32,19 @@ public class AiApiInvoker {
         AiApiInvoker.properties = chatProperties;
     }
 
+    /**
+     * Overload for calls that carry an uploaded file: delegates to the five-argument
+     * {@code invoke} with a {@code null} request body.
+     *
+     * @param apiEnum    API descriptor to call
+     * @param pathParams values substituted into the URL template; may be empty
+     * @param para       additional string parameters
+     * @param file       uploaded file to forward
+     * @return the string returned by the five-argument {@code invoke}
+     */
+    public static String invoke(AiApiEnum apiEnum, Object[] pathParams, Map<String, String> para, MultipartFile file) {
+        return invoke(apiEnum, pathParams, null, para, file);
+    }
+
+    /**
+     * Overload for calls that only carry string parameters: delegates to the five-argument
+     * {@code invoke} with a {@code null} request body and a {@code null} file.
+     *
+     * @param apiEnum    API descriptor to call
+     * @param pathParams values substituted into the URL template; may be empty
+     * @param para       additional string parameters
+     * @return the string returned by the five-argument {@code invoke}
+     */
     public static String invoke(AiApiEnum apiEnum, Object[] pathParams, Map<String, String> para) {
-        return invoke(apiEnum, pathParams, null, para);
+        return invoke(apiEnum, pathParams, null, para, null);
     }
 
+    /**
+     * Overload for calls that carry a string request body: delegates to the five-argument
+     * {@code invoke} with {@code null} parameters and a {@code null} file.
+     *
+     * @param apiEnum     API descriptor to call
+     * @param pathParams  values substituted into the URL template; may be empty
+     * @param requestBody request body string
+     * @return the string returned by the five-argument {@code invoke}
+     */
     public static String invoke(AiApiEnum apiEnum, Object[] pathParams, String requestBody) {
-        return invoke(apiEnum, pathParams, requestBody, null);
+        return invoke(apiEnum, pathParams, requestBody, null, null);
     }
 
+
     /**
      * 统一调用 AI 接口
      *
@@ -44,21 +53,23 @@ public class AiApiInvoker {
      * @param requestBody 请求体(POST/PUT 时使用),可为 String 或 Object(需序列化)
      * @return 响应字符串
      */
-    public static String invoke(AiApiEnum apiEnum, Object[] pathParams, String requestBody, Map<String, String> para) {
+    public static String invoke(AiApiEnum apiEnum, Object[] pathParams, String requestBody, Map<String, String> para, MultipartFile file) {
         String url = pathParams != null && pathParams.length > 0
                 ? apiEnum.formatUrl(pathParams)
                 : apiEnum.getUrl();
 
         OkHttpUtils builder = OkHttpUtils.builder()
                 .url(properties.getUrl() + url)
-                .addHeader("Authorization", properties.getHeaderPrefix() + " " + properties.getHeader())
-                .addHeader("Content-Type", "application/json");
+                .addHeader("Authorization", properties.getHeaderPrefix() + " " + properties.getHeader());
+        if (!StrUtil.equals(apiEnum.getType(), "form")) {
+            builder.addHeader("Content-Type", "application/json");
+        }
 
         String method = apiEnum.getMethod().toUpperCase();
         try {
             switch (method) {
                 case "GET" -> {
-                    if(CollUtil.isNotEmpty(para)){
+                    if (CollUtil.isNotEmpty(para)) {
                         for (Map.Entry<String, String> map : para.entrySet()) {
                             builder.addPara(map.getKey(), map.getValue());
                         }
@@ -66,10 +77,26 @@ public class AiApiInvoker {
                     return builder.get().sync();
                 }
                 case "POST" -> {
-                    if (StrUtil.isNotBlank(requestBody)) {
-                        builder.addBodyJsonStr(requestBody);
+                    if (StrUtil.equals(apiEnum.getType(), "form")) {
+                        MultipartBody.Builder formBody = new MultipartBody.Builder();
+                        formBody.setType(MultipartBody.FORM);
+                        String fileName = file.getOriginalFilename();
+                        byte[] bytes = file.getBytes();
+                        if (CollUtil.isNotEmpty(para)) {
+                            for (Map.Entry<String, String> map : para.entrySet()) {
+                                formBody.addFormDataPart(map.getKey(), map.getValue());
+                            }
+                        }
+                        formBody.addFormDataPart("file", fileName, RequestBody.create(bytes, MediaType.parse(StrUtil.isNotBlank(file.getContentType()) ? file.getContentType() : "audio/wav")));
+                        RequestBody body = formBody.build();
+                        return builder.form(body).sync();
+                    } else {
+                        if (StrUtil.isNotBlank(requestBody)) {
+                            builder.addBodyJsonStr(requestBody);
+                        }
+                        return builder.post(true).sync();
+
                     }
-                    return builder.post(true).sync();
                 }
                 case "DELETE" -> {
                     return builder.delete().sync();

+ 40 - 5
sckw-ai-biz/src/main/java/com/sckw/ai/biz/service/ChatService.java

@@ -1,6 +1,5 @@
 package com.sckw.ai.biz.service;
 
-
 import cn.hutool.core.bean.BeanUtil;
 import cn.hutool.core.collection.CollUtil;
 import com.alibaba.fastjson.JSONObject;
@@ -13,14 +12,15 @@ import com.sckw.ai.biz.core.web.LoginUserHolder;
 import com.sckw.ai.biz.pojo.ConversionVo;
 import com.sckw.ai.biz.pojo.dto.ApiPage;
 import com.sckw.ai.biz.pojo.dto.ConversionDto;
-import com.sckw.ai.biz.pojo.para.ChatInputsPara;
-import com.sckw.ai.biz.pojo.para.ChatPara;
-import com.sckw.ai.biz.pojo.para.ConversionPara;
-import com.sckw.ai.biz.pojo.para.StopChatPara;
+import com.sckw.ai.biz.pojo.para.*;
 import jakarta.annotation.PostConstruct;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.ResponseEntity;
 import org.springframework.stereotype.Service;
+import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.reactive.function.client.WebClient;
 import reactor.core.publisher.Flux;
 
@@ -116,4 +116,39 @@ public class ChatService {
         objectApiPage.setHasMore(res.isHasMore());
         return CommonResult.ok("", objectApiPage);
     }
+
+    /**
+     * Converts text to speech by posting the request to the upstream text-to-audio API and
+     * relaying the response body as a stream of data buffers.
+     *
+     * <p>The logged-in user's id is written into {@code chatPara} before the call. The
+     * returned entity is always labelled {@code audio/mpeg}.
+     *
+     * @param chatPara request payload; its {@code user} field is overwritten here
+     * @return a 200 response whose body is the upstream audio stream
+     * @throws BusinessException if no user is logged in
+     */
+    public ResponseEntity<Flux<DataBuffer>> text2Audio(TextPara chatPara) {
+        Long currentUserId = LoginUserHolder.getUserId();
+        if (currentUserId == null) {
+            throw new BusinessException("未登录,请先登录");
+        }
+        chatPara.setUser(String.valueOf(currentUserId));
+
+        String authorization = chatProperties.getHeaderPrefix() + " " + chatProperties.getHeader();
+        Flux<DataBuffer> upstreamAudio = webClient.post()
+                .uri(AiApiEnum.TEXT.getUrl())
+                .bodyValue(chatPara)
+                .header("Authorization", authorization)
+                .header("Content-Type", "application/json")
+                .retrieve()
+                .bodyToFlux(DataBuffer.class);
+
+        // Key point: tell the client the payload is MP3 audio. Adding a Content-Disposition
+        // "attachment" header here would make browsers download instead of play it.
+        return ResponseEntity.ok()
+                .header(HttpHeaders.CONTENT_TYPE, "audio/mpeg")
+                .body(upstreamAudio);
+    }
+
+    public CommonResult<TextVo> audio2Text(MultipartFile file) {
+        Long userId = LoginUserHolder.getUserId();
+        if (Objects.isNull(userId)) {
+            throw new BusinessException("未登录,请先登录");
+        }
+
+        String invoke = AiApiInvoker.invoke(AiApiEnum.AUDIT, new Object[]{}, new HashMap<>() {{
+            put("user", String.valueOf(userId));
+            put("word_timestamps", "disabled");
+        }}, file);
+        TextVo res = JSONObject.parseObject(invoke, new TypeReference<>() {
+        });
+        return CommonResult.ok("查询成功", res);
+    }
+
 }

+ 4 - 0
sckw-ai-biz/src/main/java/com/sckw/ai/biz/util/OkHttpUtils.java

@@ -178,6 +178,10 @@ public class OkHttpUtils {
         request.url(buildUrl());
         return this;
     }
+    /**
+     * Initialises a POST request that sends the given, already-built request body.
+     *
+     * <p>Replaces the current request builder with a fresh {@code Request.Builder} targeting
+     * {@code buildUrl()}. {@code AiApiInvoker} uses this for multipart form uploads.
+     *
+     * @param requestBody body to send as-is
+     * @return this instance, for chaining
+     */
+    public OkHttpUtils form(RequestBody requestBody) {
+        request = new Request.Builder().post(requestBody).url(buildUrl());
+        return this;
+    }
     /**
      * 初始化post方法
      */