2024-08-30
Go & 后端
00

目录

1 修改问题接口
2 代码提交接口
3 引入 ProtoBuffer 和 Grpc,做远程代码执行
3.1 proto 文件定义
3.2 Register 服务的两端实现
3.3 ExecCode 服务的客户端实现
3.4 ExecCode 服务的服务端实现

MisakaOJ项目,第五次记录内容

1 修改问题接口

提示

注意 categories 和 test_cases 两个参数的传参方式,只有这样才能被PostFormArray函数识别。

如果是这样:// @Param categories formData []string false "categories" collectionFormat(multi),那么实际传的方式是:

b70c1118c119fc85d2f62e7192d58c67.png

如果是这样:// @Param categories formData array false "categories",那么实际传的方式是:

ae0452e3472a503a630a704ef1282760.png

那么到后端的样子就是:["贪心,二叉树"]

Handler层:

go
// ModifyQuestion // @Tags 管理员接口 // @Summary 修改题目 // @Param authorization header string true "authorization" // @Param identity formData string true "identity" // @Param title formData string false "title" // @Param content formData string false "content" // @Param max_runtime formData int false "max_runtime" // @Param max_mem formData int false "max_mem" // @Param categories formData []string false "categories" collectionFormat(multi) // @Param test_cases formData []string false "test_cases" collectionFormat(multi) // @Success 200 {data} json "{"code": "200", "message": ""}" // @Router /admin/modify_question [post] func ModifyQuestion(c *gin.Context) { // 参数解析 除 identity 以外别的参数不再必须 identity := c.PostForm("identity") if identity == "" { ErrorHandler(c, constants.ParameterMissingErr.Error()+"identity") return } modifyQuestion := &models.Question{} tx := models.GetQuestionDetail(identity) // 把要修改的问题找到 e := tx.First(modifyQuestion).Error if e != nil { if errors.Is(e, gorm.ErrRecordNotFound) { ErrorHandler(c, constants.QuestionNotExistErr.Error()+"identity: "+identity) return } ErrorHandler(c, errors.Join(constants.DataBaseQueryErr, e).Error()) return } title := c.PostForm("title") if title != "" { modifyQuestion.Title = title } content := c.PostForm("content") if content != "" { modifyQuestion.Content = content } maxRuntime, e := strconv.Atoi(c.PostForm("max_runtime")) if maxRuntime < 0 { ErrorHandler(c, constants.ParameterParseErr.Error()+"max_runtime: less than 0") return } if maxRuntime != 0 { modifyQuestion.MaxRuntime = maxRuntime } maxMem, e := strconv.Atoi(c.PostForm("max_mem")) if maxMem < 0 { ErrorHandler(c, constants.ParameterParseErr.Error()+"max_mem: less than 0") return } if maxMem != 0 { modifyQuestion.MaxMem = maxMem } categoryArrayFromUser := c.PostFormArray("categories") testCaseArrayFromUser := c.PostFormArray("test_cases") // 问题分类查库 初始化 var questionCategoryArray []*models.QuestionCategory if len(categoryArrayFromUser) != 0 { questionCategoryArray = make([]*models.QuestionCategory, 0) var questionCategory *models.QuestionCategory for i := range categoryArrayFromUser { // todo 分类id和分类的关系也可以放进redis里 加快访问 // 初始化 questionCategory = &models.QuestionCategory{} questionCategory.Category = &models.Category{} // 查询 e = models.GetCategoryByColumn("name", categoryArrayFromUser[i]).First(questionCategory.Category).Error if e != nil { if errors.Is(e, gorm.ErrRecordNotFound) { continue } ErrorHandler(c, errors.Join(constants.DataBaseQueryErr, e).Error()) return } // 放结果 questionCategory.CategoryId = questionCategory.Category.ID questionCategory.QuestionId = modifyQuestion.ID questionCategoryArray = append(questionCategoryArray, questionCategory) } } // 测试用例初始化 var testCasesArray []*models.TestCase if len(testCaseArrayFromUser) != 0 { testCasesArray = make([]*models.TestCase, 0) var caseMap map[string]string var singleCase *models.TestCase var ok bool for i := range testCaseArrayFromUser { e = json.Unmarshal([]byte(testCaseArrayFromUser[i]), &caseMap) if e != nil { continue } singleCase = &models.TestCase{ QuestionIdentity: modifyQuestion.Identity, } singleCase.Input, ok = caseMap["input"] if !ok { continue } singleCase.Output, ok = caseMap["output"] if !ok { continue } testCasesArray = append(testCasesArray, singleCase) } } // 开始事务 if e = models.DB.Transaction(func(tx *gorm.DB) error { // 更新题目信息 var txErr error txErr = tx.Save(modifyQuestion).Error if txErr != nil { return txErr } // 确定是否要更新分类问题关联 if questionCategoryArray != nil { // 删除原有的分类问题关联 txErr = tx.Where("question_id = ?", modifyQuestion.ID).Delete(&models.QuestionCategory{}).Error if txErr != nil { return txErr } // 然后将新的关联插入 txErr = tx.Create(questionCategoryArray).Error if txErr != nil { return txErr } } // 确定是否要更新测试用例 if testCasesArray != nil { // 还是先删再建 txErr = tx.Where("question_identity = ?", modifyQuestion.Identity).Delete(&models.TestCase{}).Error if txErr != nil { return txErr } txErr = tx.Create(testCasesArray).Error if txErr != nil { return txErr } } return nil }); e != nil { ErrorHandler(c, errors.Join(constants.DataBaseUpdateErr, e).Error()) return } c.JSON(http.StatusOK, gin.H{ "code": 200, "message": "Modify Question Successful! ", }) }

2 代码提交接口

提示

gorm 中,Save 方法和 Update/Updates 方法有以下区别:

  • Save 方法不管传入的 model 修改了多少,会按传入的 model 对整条记录进行更新。假设从数据库中拿到一个 User 结构体,修改其中一个字段,再传给 Save,那么 Save 不仅会对被修改的字段进行更新,也会对其他的字段进行更新,即使该字段没有被修改。
  • 另外,传入 Save 方法的 model 中如果有值没有被初始化,那么数据库中对应的字段会被修改为默认值
  • Update 方法是指定一个字段进行修改。
  • Updates 方法会对传入的 model 进行检查,并且有选择地进行更新。一般情况下,会选择有值且该值被修改的字段进行更新,并且忽略未初始化的零值,不会更新整条记录

先修改一下认证中间件:

go
// AuthUserMiddleWare 验证是否为普通用户的中间件 func AuthUserMiddleWare() gin.HandlerFunc { return func(c *gin.Context) { token := c.GetHeader("Authorization") if token == "" { c.Abort() handler.ErrorHandler(c, constants.AuthorizationUserFailed.Error()) return } userIdentity, isAdmin, e := util.ParseToken(token) if e != nil { c.Abort() handler.ErrorHandler(c, e.Error()) return } if isAdmin == 0 { c.Abort() handler.ErrorHandler(c, constants.AuthorizationUserFailed.Error()) return } c.Set("userIdentity", userIdentity) // 把 identity 放进 context 里面 让后面的 handler 方便查 c.Next() } }

通过把 identity 放进 context,之后的用户相关的接口就只传一个 token 就够了。

之后修改一下对应的获取用户信息接口:

go
// GetUserDetail // @Tags 用户接口 // @Summary 获取用户信息 // @Param Authorization header string true "Authorization" // @Success 200 {data} json "{"code": "200", "data": ""}" // @Router /user/detail [get] func GetUserDetail(c *gin.Context) { i, isExist := c.Get("userIdentity") if !isExist { // 如果没传identity c.JSON(http.StatusOK, gin.H{ "code": -1, "message": "Parameter identity is missed! ", }) return } identity := i.(string) ... }

代码提交,分成几个部分:

  • 在接口里拿到提交的东西
  • 将提交的代码进行保存
  • 运行保存的代码
  • 接口获取运行结果,存到数据库里,并且返回回去

代码保存:

go
var SubmitCodeSavePath = "./code/" func SaveCodeToFile(codeBytes []byte, submitIdentity string) (string, error) { filePath := constants.SubmitCodeSavePath + submitIdentity e := os.Mkdir(filePath, os.ModePerm) // 777权限 if e != nil { return "", errors.Join(constants.MkdirErr, e) } filePath += "/main.go" file, e := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, os.ModePerm) if e != nil { return "", errors.Join(constants.NewFileErr, e) } _, e = file.Write(codeBytes) if e != nil { return "", errors.Join(constants.WriteFileErr, e) } return filePath, nil }

代码运行:

go
package codeExec import ( "MisakaOJ/models" "MisakaOJ/util" "bytes" "fmt" context2 "golang.org/x/net/context" "io" "os/exec" "strings" "sync" "time" ) // 一个解决不了的问题是 开了那么多协程 执行器却只有那几个 // 协程过多 执行器过少 调度时总会有几个先挂起等待下一次调度 // 这就导致了时间测量不准 在 testCase 过多的时候极其容易超时 // 还有一个问题是,go run 是连编译带运行一起 测量内存只能有一些参考价值 // 两数相加都要30MB起步 只能说是仅作为参考了 // 还有 这东西运行有点不稳定 我查了几次 不清楚是不是 go 本身的问题 有的指针都不知道指到哪里去了 堆栈上面全是报错 // 只能说是凑合用的东西 // CodeResultStatus 用于标识代码执行情况 type CodeResultStatus int const ( Finish CodeResultStatus = iota WrongAnswer OutOfMemory OutOfTime CompileError ExecError UnexpectedExecError ) type CodeExecResult struct { status CodeResultStatus message string caseIndex int } func ExecCode(codePath string, testCases []*models.TestCase, questionMaxTime int, questionMaxMemory int64) *CodeExecResult { var context context2.Context var cancelFunc context2.CancelFunc resultChannel := make(chan *CodeExecResult, len(testCases)/2+1) context, cancelFunc = context2.WithTimeout(context2.Background(), time.Duration(questionMaxTime)*time.Millisecond) // 按毫秒计时 wg := &sync.WaitGroup{} defer func() { cancelFunc() close(resultChannel) }() for i := range testCases { go ExecCodePerCase(context, resultChannel, codePath, testCases[i].Input, testCases[i].Output, i, questionMaxMemory, wg) wg.Add(1) } var result *CodeExecResult for _ = range len(testCases) { select { case <-time.After(time.Duration(questionMaxTime+1000) * time.Millisecond): // 超时 在原有的基础上再多等一秒 确保超时的报错信息也能传进来 break case r := <-resultChannel: if r.status != Finish { result = r break } } } if result == nil { result = &CodeExecResult{status: Finish} } wg.Wait() return result } func ExecCodePerCase(context context2.Context, resultChannel chan<- *CodeExecResult, codePath, input, output string, caseIndex int, questionMaxMemory int64, wg *sync.WaitGroup) { defer func() { wg.Done() if e := recover(); e != nil { resultChannel <- &CodeExecResult{ status: UnexpectedExecError, message: fmt.Sprintf("%v", e), caseIndex: caseIndex, } } }() cmd := exec.Command("go", "run", codePath) // 构建执行命令 var out, err bytes.Buffer var inPipe io.WriteCloser var e error cmdChannel := make(chan error, 1) defer close(cmdChannel) // 拿到对应的输入 输出 报错信息 inPipe, e = cmd.StdinPipe() if e != nil { resultChannel <- &CodeExecResult{ status: UnexpectedExecError, message: e.Error(), caseIndex: caseIndex, } return } cmd.Stdout = &out cmd.Stderr = &err // 输入 _, e = io.WriteString(inPipe, input) if e != nil { resultChannel <- &CodeExecResult{ status: UnexpectedExecError, message: e.Error(), caseIndex: caseIndex, } return } go func() { _ = cmd.Start() cmdChannel <- cmd.Wait() }() go func() { for { if cmd.Process != nil { break } } memCost, e := util.GetPeakWorkingSetByPid(cmd.Process.Pid) if e != nil { resultChannel <- &CodeExecResult{ status: UnexpectedExecError, message: "Cannot Measure Memory Usage: " + e.Error(), caseIndex: caseIndex, } return } if memCost > questionMaxMemory { resultChannel <- &CodeExecResult{ status: OutOfMemory, message: fmt.Sprintf("Memory Usage: %d", memCost), caseIndex: caseIndex, } return } }() select { case <-context.Done(): // 超时 关闭子进程 _ = cmd.Process.Kill() resultChannel <- &CodeExecResult{ status: OutOfTime, caseIndex: caseIndex, } return case e = <-cmdChannel: // 正常结束 } if e != nil { errString := err.String() // 如果是编译错误 那么一般会有 # command-line-arguments 前缀 并且退出代码为一般为1 // 如果是 panic 那么一般报错中会有 panic 并且退出代码为一般为2 // 注意 看退出码实在不靠谱 还是字符串检测吧 if strings.Contains(errString, "# command-line-arguments") { resultChannel <- &CodeExecResult{ status: CompileError, message: "Compile Error: %d" + errString, caseIndex: caseIndex, } } else if strings.Contains(errString, "panic") { resultChannel <- &CodeExecResult{ status: ExecError, message: "Execute Error: %d" + errString, caseIndex: caseIndex, } } else { resultChannel <- &CodeExecResult{ status: UnexpectedExecError, message: e.Error() + " " + errString, caseIndex: caseIndex, } } return } if out.String() != output { resultChannel <- &CodeExecResult{ status: WrongAnswer, message: "Wrong Answer: " + out.String(), caseIndex: caseIndex, } return } resultChannel <- &CodeExecResult{ status: Finish, caseIndex: caseIndex, } return } // GetPeakWorkingSetByPid 函数如下: // GetPeakWorkingSetByPid 按照pid寻找进程 获取该进程的最高内存占用 该函数耗时约300毫秒 func GetPeakWorkingSetByPid(pid int) (int64, error) { cmd := exec.Command("powershell", "-Command", fmt.Sprintf("Get-Process -Id %d | Select-Object -Property PeakWorkingSet64", pid)) output, e := cmd.Output() if e != nil { return 0, e } lines := strings.Split(string(output), "\r") if len(lines) < 3 { return 0, fmt.Errorf("unexpected output from powershell") } // 获取到的结果类似这样 // PeakWorkingSet64 // -------------- // 4025053184 result, e := strconv.ParseInt(strings.TrimSpace(lines[3]), 10, 64) return result, e }

最后是 Handler 的部分。

go
// NewSubmit // @Tags 用户接口 // @Summary 提交代码 // @Param Authorization header string true "Authorization" // @Param question_identity query string true "question_identity" // @Param code body string true "code" // @Success 200 {data} json "{"code": "200", "data": ""}" // @Router /user/submit [post] func NewSubmit(c *gin.Context) { i, isExist := c.Get("userIdentity") if !isExist { ErrorHandler(c, constants.ParameterMissingErr.Error()+"authorization") return } userIdentity := i.(string) // 拿参数 questionIdentity := c.Query("question_identity") if questionIdentity == "" { ErrorHandler(c, constants.ParameterMissingErr.Error()+"question_identity") return } codeBytes, e := io.ReadAll(c.Request.Body) if e != nil { ErrorHandler(c, errors.Join(constants.ReadRequestBodyErr, e).Error()) return } // 保存提交的code submitIdentity := util.GenerateUUID() codePath, e := util.SaveCodeToFile(codeBytes, submitIdentity) if e != nil { ErrorHandler(c, e.Error()) return } // 拿题目数据和测试用例 question := &models.Question{} e = models.GetQuestionDetail(questionIdentity).Preload("TestCases").Find(question).Error if e != nil { if errors.Is(e, gorm.ErrRecordNotFound) { ErrorHandler(c, constants.ParameterParseErr.Error()+"question_identity") return } ErrorHandler(c, errors.Join(constants.DataBaseQueryErr, e).Error()) return } // 执行 var codeExecResult *codeExec.Result // 是否远程执行代码 if constants.ExecCodeRemotely { // todo 远程执行代码这块先留着 } else { codeExecResult = codeExec.ExecCode(codePath, question.TestCases, question.MaxRuntime, int64(question.MaxMem)) } // 根据结果写入数据库 返回结果 // 无论执行成功与否 要写的一定有 submit 表和 user question 表里面的 submit_num 字段 // 再根据成功与否 写 user question 表里面的 finish_question_num 字段 txErr := models.DB.Transaction(func(tx *gorm.DB) error { // 先写 submit newSubmit := &models.Submit{ Identity: util.GenerateUUID(), QuestionIdentity: questionIdentity, UserIdentity: userIdentity, Path: codePath, Status: int(codeExecResult.Status), } e = tx.Create(newSubmit).Error if e != nil { return e } // 再写 user 和 question e = tx.Model(&models.User{}).Where("identity = ?", userIdentity).Update("submit_num", gorm.Expr("submit_num + ?", 1)).Error if e != nil { return e } e = tx.Model(&models.Question{}).Where("identity = ?", questionIdentity).Update("submit_num", gorm.Expr("submit_num + ?", 1)).Error if e != nil { return e } // 再根据代码执行结果决定是否加 finish_question_num if codeExecResult.Status == codeExec.Finish { e = tx.Model(&models.User{}).Where("identity = ?", userIdentity).Update("finish_question_num", gorm.Expr("finish_question_num + ?", 1)).Error if e != nil { return e } e = tx.Model(&models.Question{}).Where("identity = ?", questionIdentity).Update("finish_num", gorm.Expr("finish_num + ?", 1)).Error if e != nil { return e } } return nil }) if txErr != nil { ErrorHandler(c, errors.Join(constants.DataBaseUpdateErr, txErr).Error()) return } c.JSON(http.StatusOK, gin.H{ "code": 200, "data": codeExecResult, }) }

效果:

image.png

image.png

image.png

3 引入 ProtoBuffer 和 Grpc,做远程代码执行

3.1 proto 文件定义

protoc 编译命令:

shell
// 编译message 没有rpc protoc --go_out=./ ./codeExec.proto // 有rpc 没有message protoc --go-grpc_out=./ ./codeExec.proto // 两个一起编译 protoc --go-grpc_out=./ --go_out=./ ./codeExec.proto

提示

这块始终有一些问题没能解决,还有后来才意识到的,其实 RPC 可以选择双向流模式而不是一元 RPC 模式,还可以用消息队列做订阅发布模式,只能说以后再说了。

先定义 proto 文件:

proto3
syntax = "proto3"; package proto; option go_package = "."; // 远程执行代码的请求 message 结构 message ExecCodeRequest { string input = 1; string output = 2; string code = 3; int32 maxMem = 4; int32 maxTime = 5; int32 test_case_index = 6; } // 远程执行代码的返回结果 message 结构 message ExecCodeResponse { int32 result_status = 1; string message = 2; int32 test_case_index = 3; } // 远程执行代码的具体的 rpc 函数 service ExecCode { rpc ExecCodeRemote(ExecCodeRequest) returns (ExecCodeResponse); } message RegisterRequest { string address = 1; } message RegisterResponse {} service Register { rpc Register(RegisterRequest) returns (RegisterResponse); }

在用上面的命令编译后,得到俩文件:

注意,如果别的包要引入该包下的结构体和函数,要这么显式引用:import pb "Misaka/proto"

提示

注意,给 RPC 服务和函数命名的时候,尽量别用 Register,否则就会出现类似RegisterRegisterServer的糟糕命名函数。而且 grpc 生成的函数是没有注释的,看着更难绷了。

上面的 proto 文件,定义了两个服务:一个 Register 服务,一个 ExecCode 服务。对于 Register 服务来说,MisakaOJ 这边是服务端,但是对于 ExecCode 服务来说,它又变成了客户端。

3.2 Register 服务的两端实现

先说 MisakaOJ 这边的 Register 服务。

go
// 这一侧当服务端 接收另一侧当客户端发起的注册 RPC 调用 type RegisterSer struct { pb.UnimplementedRegisterServer } // Register 这一侧当服务端的 Register 具体实现 func (r RegisterSer) Register(c context.Context, request *pb.RegisterRequest) (*pb.RegisterResponse, error) { conn, e := grpc.Dial( "localhost"+request.Address, grpc.WithInsecure(), // 启用不安全的连接 grpc.WithBlock(), // 阻塞协程直到连接建立 ) if e != nil { return nil, e } client := pb.NewExecCodeClient(conn) availableClientChan <- &ClientWithConn{ client: client, conn: conn, } return &pb.RegisterResponse{}, nil } // RegisterServerInit 初始化并且监听注册服务 func RegisterServerInit() error { listen, e := net.Listen("tcp", constants.RegisterRemoteServerPort) if e != nil { return e } grpcServer := grpc.NewServer() pb.RegisterRegisterServer(grpcServer, &RegisterSer{}) // todo 自定义一个 Logger 定义一个全局的 channel 这样所有的错误都可以放进来 go func() { _ = grpcServer.Serve(listen) }() return nil }

可以看到,grpc 成为服务端有三个步骤:

  1. 找到对应的接口。在这个实例中是UnimplementedRegisterServer
  2. 实现对应的服务端函数,在这个实例中是Register函数。
  3. 监听某一端口,实例化服务端,在这个实例中是RegisterSer结构体。通过Register_Server函数注册该结构体实例,最后监听端口。

再看客户端:

go
// Submit 到对面的注册 func Submit() error { conn, e := grpc.Dial( "localhost"+serverPort, grpc.WithInsecure(), // 启用不安全的连接 grpc.WithBlock(), // 阻塞协程直到连接建立 ) if e != nil { return e } registerClient := pb.NewRegisterClient(conn) _, e = registerClient.Register(context.Background(), &pb.RegisterRequest{Address: selfServerPort}) return e }

grpc 成为客户端也有三个步骤:

提示

grpc.Dial函数被标识为已弃用,新的函数是grpc.NewServer。但是该函数在2024-09-04这个时间对于某些DialOption,比如grpc.WithBlock(),还没有支持,所以只能先用着。

  1. 通过grpc.Dial函数建立连接。
  2. 获取客户端实例。
  3. 通过客户端实例调用 RPC 函数。

3.3 ExecCode 服务的客户端实现

go
type ExecCodeCallbackFunc func(response *pb.ExecCodeResponse) type ExecCodeRequestWithCallBack struct { request *pb.ExecCodeRequest callback ExecCodeCallbackFunc } type ClientWithConn struct { client pb.ExecCodeClient conn *grpc.ClientConn // 保存这个 *grpc.ClientConn 是为了方便我手动结束连接 } var execCodeRequestChan = make(chan *ExecCodeRequestWithCallBack, 50) // 所有远程执行代码的请求都经过这里 var availableClientChan = make(chan *ClientWithConn, 10) // 客户端连接用完放到这里 func HandleExecCodeRemoteRequest() { for { requestWithCallback := <-execCodeRequestChan client := <-availableClientChan go func(req *ExecCodeRequestWithCallBack, client *ClientWithConn) { // todo 自定义log response, e := client.client.ExecCodeRemote(context.Background(), requestWithCallback.request) if e != nil { log.Println(response.String() + " " + e.Error()) } else { log.Println(response.String()) } // 错误处理 规定如果服务端执行代码过程中的所有错误全部定义为 Internal if e != nil { // 这个 status 和 code 都是 grpc 包下的 // 这个 status 是由 code message(string) 和 detail([]any) 组成的 标识一个错误和具体的错误信息 // 这些 code 也是 grpc 预先定义好的 注意 有些 code 是框架自己因为种种原因自己生成的 但是有些 code 是服务端的实现返回的 // 比如 codes.Unavailable 如果是服务端实现中使用这个错误码 对于客户端来说应当重新发起请求重试 // 但是如果是 grpc 框架生成的 就说明是网络中断或者服务端进程中断导致无法连接到服务端 // 在远程执行代码这个情境下 我的所有的错误都会写进 response 中 不会写到 error 中 // 所以一旦出问题 一定是 grpc 框架本身报出来的错误 if st, ok := status.FromError(e); ok { execCodeRequestChan <- req switch st.Code() { case codes.Unavailable: // 无法连接至服务端 处理方法是把这个 request 放回到 channel 中重新排队 // 但是服务端不可用 就关闭这个客户端 等待客户端自动重连 最后结束这个协程 _ = client.conn.Close() return case codes.DeadlineExceeded: // 连接超时 处理方法是把这个 request 放回到 channel 中重新排队 客户端也是 availableClientChan <- client return case codes.Unknown: // 未知错误 或者错误没有足够信息 处理方式同 Unavailable _ = client.conn.Close() return } } } // 回调 让 ExecCodeRemote 获取到结果 req.callback(response) // 客户端用完回去接着排队 availableClientChan <- client }(requestWithCallback, client) } } // ExecCodeRemote 对别的包的接口 在远程跑所有的测试用例 func ExecCodeRemote(codeBytes []byte, testCases []*models.TestCase, questionMaxTime int, questionMaxMemory int64) *Result { result := &Result{} wg := sync.WaitGroup{} resultHandleFunc := func(response *pb.ExecCodeResponse) { if response.ResultStatus != 1 { result.Status = CodeResultStatus(response.ResultStatus) result.Message = response.Message result.CaseIndex = int(response.TestCaseIndex) } wg.Done() } for i := range testCases { execCodeRequestChan <- &ExecCodeRequestWithCallBack{ request: &pb.ExecCodeRequest{ Input: testCases[i].Input, Output: testCases[i].Output, Code: string(codeBytes), MaxMem: int32(questionMaxMemory), MaxTime: int32(questionMaxTime), TestCaseIndex: int32(i), }, callback: resultHandleFunc, } wg.Add(1) } wg.Wait() return result }

在客户端(实质上是 MisakaOJ 这一侧作为客户端发起请求)这边,要重点说明的有这么几个问题:

  1. 因为 grpc 是基于 http2 的,所以某些特质上有点像 http,比如错误码这个东西。在 http 中,如果一个请求返回了404,既可以说明没有连接到服务端,也可以说服务端没有找到对应请求的资源。在 grpc 中也是这样,如果调用 rpc 函数返回了错误,错误状态是codes.Unavailable的话,如果是服务端实现中使用这个错误码,对于客户端来说应当重新发起请求重试;但是如果是 grpc 框架生成的,就说明是网络中断或者服务端进程中断导致无法连接到服务端。所以我在服务端实现处避免了我自己生成 grpc 给定的这些错误状态,全部写进了我的响应中。同时,看情况决定是否放弃这个出错的客户端。
  2. 我一开始写的时候想的是实现一个服务端池,每次有需要调用的时候,拿一个空闲的服务端进行调用。后来意识到,既然服务调用完和接受一个新的连接都能产生一个空闲的服务端,而空闲服务端只有在我调用的时候使用,这就是一个多生产者一个消费者,用 channel 存起来够用了。
  3. 一个 client 实例是不能被主动关闭的,必须要调用*grpc.ClientConnClose函数才行。
  4. ExecCodeCallbackFunc类型,实质上是一个回调函数,让ExecCodeRemote拿到数据用的。

3.4 ExecCode 服务的服务端实现

go
var ( serverPort = ":22332" selfServerPort = ":22557" CodeSavePath = "./code/" ) const ( Finish int32 = iota WrongAnswer OutOfMemory OutOfTime CompileError ExecError UnexpectedExecError ) type ExecCodeServer struct { pb.UnimplementedExecCodeServer } func (ec *ExecCodeServer) ExecCodeRemote(ctx context.Context, request *pb.ExecCodeRequest) (*pb.ExecCodeResponse, error) { var e error response := &pb.ExecCodeResponse{} log.Println(request.String()) // 先保存代码 filePath := CodeSavePath + uuid.NewString() e = os.Mkdir(filePath, os.ModePerm) // 777权限 if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex return response, nil } filePath += "/main.go" file, e := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, os.ModePerm) if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex return response, nil } _, e = file.WriteString(request.Code) if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex return response, nil } contextWithTimeout, _ := context.WithTimeout(context.Background(), time.Duration(request.MaxTime)*time.Millisecond) cmd := exec.CommandContext(contextWithTimeout, "go", "run", filePath) // 构建执行命令 var out, err bytes.Buffer var inPipe io.WriteCloser // 拿到对应的输入 输出 报错信息 inPipe, e = cmd.StdinPipe() if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex return response, nil } cmd.Stdout = &out cmd.Stderr = &err // 输入 _, e = io.WriteString(inPipe, request.Input) if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex return response, nil } cmdChannel := make(chan error, 1) defer close(cmdChannel) // 这块我始终都想着是用完就关 但是因为在 Windows 下 我实在无法保证进程超时就能被立即关上从而导致 cmd.Wait 会向一个已经 close 的 channel 写东西 go func() { defer func() { _ = recover() return }() _ = cmd.Start() cmdChannel <- cmd.Wait() //cmdChannel <- cmd.Run() }() memInfo := make(chan *pb.ExecCodeResponse, 1) go func() { for { if cmd.Process != nil { break } time.Sleep(1 * time.Millisecond) // attention 怀疑在某些情况下 程序会先于 cmd.Run 跑到这里 引入 time.Sleep 就是为了防止这两个进程死锁 } memCost, e := GetPeakWorkingSetByPid(cmd.Process.Pid) if e != nil { response.ResultStatus = UnexpectedExecError response.Message = e.Error() response.TestCaseIndex = request.TestCaseIndex memInfo <- response } if int32(memCost) > request.MaxMem { response.ResultStatus = OutOfMemory response.Message = "" response.TestCaseIndex = request.TestCaseIndex memInfo <- response } }() select { case <-time.After(time.Duration(request.MaxTime) * time.Millisecond): // 超时 关闭子进程 //_ = cmd.Process.Kill() exec.Command("taskkill", "/PID", fmt.Sprint(cmd.Process.Pid), "/F").Run() response.ResultStatus = OutOfTime response.Message = "" response.TestCaseIndex = request.TestCaseIndex return response, nil case e = <-cmdChannel: // 正常结束 case response = <-memInfo: // 内存信息报错 exec.Command("taskkill", "/PID", fmt.Sprint(cmd.Process.Pid), "/F").Run() return response, nil } if e != nil { errString := err.String() // 如果是编译错误 那么一般会有 # command-line-arguments 前缀 并且退出代码为一般为1 // 如果是 panic 那么一般报错中会有 panic 并且退出代码为一般为2 // 注意 看退出码实在不靠谱 还是字符串检测吧 if strings.Contains(errString, "# command-line-arguments") { response.ResultStatus = CompileError response.Message = errString response.TestCaseIndex = request.TestCaseIndex return response, nil } else if strings.Contains(errString, "panic") { response.ResultStatus = ExecError response.Message = errString response.TestCaseIndex = request.TestCaseIndex return response, nil } else { response.ResultStatus = UnexpectedExecError response.Message = errString response.TestCaseIndex = request.TestCaseIndex return response, nil } } if out.String() != request.Output { response.ResultStatus = WrongAnswer response.Message = "" response.TestCaseIndex = request.TestCaseIndex return response, nil } response.ResultStatus = Finish response.Message = "" response.TestCaseIndex = request.TestCaseIndex return response, nil } // GetPeakWorkingSetByPid 按照pid寻找进程 获取该进程的最高内存占用 该函数耗时约300毫秒 func GetPeakWorkingSetByPid(pid int) (int64, error) { cmd := exec.Command("powershell", "-Command", fmt.Sprintf("Get-Process -Id %d | Select-Object -Property PeakWorkingSet64", pid)) output, e := cmd.Output() if e != nil { return 0, e } lines := strings.Split(string(output), "\r") if len(lines) < 3 { return 0, fmt.Errorf("unexpected output from powershell") } // 获取到的结果类似这样 // PeakWorkingSet64 // -------------- // 4025053184 result, e := strconv.ParseInt(strings.TrimSpace(lines[3]), 10, 64) return result, e } // ExecCodeServerInit 初始化接收 ExecCode 的 RPC 调用的 Server func ExecCodeServerInit() error { listen, e := net.Listen("tcp", selfServerPort) if e != nil { return e } grpcServer := grpc.NewServer() pb.RegisterExecCodeServer(grpcServer, &ExecCodeServer{}) go func() { _ = grpcServer.Serve(listen) }() return nil } func main() { var e error for { e = ExecCodeServerInit() if e == nil { break } else { log.Println("ExecCodeServerInit: " + e.Error()) return } } for { e = Submit() if e == nil { break } else { log.Println("Submit: " + e.Error()) return } } select {} }

这块也有东西要说一下,是之前写 ExecCode 没意识到的问题:

  1. 在获取进程内存信息的时候,我用了一个循环判断是否进程是否为空。但是在某些特殊情况下,程序会比cmd.Start先进循环,无限循环不释放这个cmd.Process,直接造成死锁。所以在这里加了一个time.Sleep函数中断一下。
  2. 因为cmd.Kill是一个非阻塞函数,所以有可能出现这么一个情况:cmd 成功开始了,但是超内存了,cmd.Kill执行完了并没有立即结束进程,反而是ExecCodeRemote函数先结束,defer 顺带把 cmdChannel 也给关上了。关上之后cmd.Wait也结束,在试图把值写进已经关闭的 cmdChannel 时引发 panic。为了避免这个问题,使用了 taskkill 命令,并且给那个协程加了一个没有实际作用,仅仅是为了避免 panic 的 recover 函数。因为使用 taskkill 命令依然不能完全避免上述情况的发生。

总的来说,这个方案实在是漏洞百出,以后看看要么走 CGO 的路子,要么传进来的代码是一个函数而不是一个 main.go,在 main 函数里面加 pprof 试试。

本文作者:御坂19327号

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!