通用技术 记一次基于 AST 完成的代码注入场景

流浪豆 · 2023年03月15日 · 最后由 流浪豆 回复于 2023年03月16日 · 5395 次阅读

背景

某业务团队提出脚本性能测试的需求:想知道脚本 A 中的耗时具体在哪里。
PS:刚拿到该需求的时候我哭笑不得,从需求来看就是普通的跑一次脚本获取具体的耗时在哪里😂(后续性能调优也还是研发,测试仅做了个执行脚本的操作)。。。但迫于刚把测试团队的性能测试推广出去(热乎的新活),总不能直接把自己的路堵死了吧。
图1


解决方案

简单扫了一遍脚本的调用, 整个脚本还是比较简单(也就涉及到 6、7 个方法), 这样直接手动添加耗时统计效率最快。
【方案】:人为在每个方法中添加耗时的统计方法,然后重新编译并执行。 这样就得到了具体脚本耗时最大头是在方法 XX 上, 并通过pprof 工具对结果进行了一些分析,最终将结果反馈给了研发。

事情进阶

本以为这个事情就结束了,后来业务团队又提出了更多的脚本性能需求(复杂的涉及到好几十个方法调用 )。。。
图2

一个个手工去添加重复性又太高,并且效率比较低, 原有的手工去添加耗时的方法就行不通了😓。由于项目没有接入 APM 这类自动获取调用链路以及耗时的工具(临时去接这种工具可行性不高),那怎么去快速的在这么多方法中注入耗时统计的代码呢 ( ̄ε  ̄) ?

此时想到了原来接触的静态扫描中的一一AST能够满足自动注入代码的需求。


优化版方案

通过对脚本分析, 脚本调用的方法都是在当前项目内, 所以新的方案只需要满足对指定目录/文件的方法进行耗时统计代码的注入即可。二话不说,撸起袖子就是干!此处不详细介绍 AST 的基础知识,有兴趣可以看网上一些资料(个人感觉都就只简单讲了开始的一点入门,没有什么具体的深入了(不同语言生成的语法树差异不大,比如:go、js、java 等):

PS:下面是 Demo 代码, 有兴趣的可以自己本地 GoLand 编译器上尝试一下(⚠️ 注意:Demo 版本代码比较粗糙,不要在意一些编程规范或实现 ( ̄ε  ̄))。

demo.go 文件(被注入的对象)

package main

import (
    "fmt"
)

func greet() {
  // test 一般代码首行存在注释
    total := 0
    for i := 1; i <= 1000; i++ {
        total += i
    }
    fmt.Printf("total = %v\n", total)
}

func test() {}

核心实现的代码 main.go

package main

import (
    "bufio"
    _ "bufio"
    "bytes"
    "fmt"
    "go/ast"
    "go/format"
    "go/parser"
    "go/token"
    "log"
    "os"
    _ "os"
    "path/filepath"
    "strconv"
)

type Visitor struct {
}

func addStartTimeCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        startCodePos := bodyStmt.Pos()
        xValue := &ast.Ident{
            Name: "time",
            NamePos: startCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Now",
            NamePos: startCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Lparen: startCodePos,
            Rparen: startCodePos,
        }
        assignStmt := &ast.AssignStmt{
            Lhs:    []ast.Expr {ast.NewIdent("startT")},
            TokPos: startCodePos,
            Tok: token.DEFINE,
            Rhs: []ast.Expr { callExpr },
        }
        listNew := []ast.Stmt {assignStmt}
        bodyStmt.List = append(listNew, listStmt...)
    }
}

func parseFuncEndPos(blockStmt *ast.BlockStmt) token.Pos {
    bodyList := blockStmt.List
    lastElement := bodyList[len(bodyList)-1]
    funcEndPos := lastElement.End()
    return funcEndPos
}
func addEndTimeCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        endCodePos := parseFuncEndPos(bodyStmt)
        xValue := &ast.Ident{
            Name: "time",
            NamePos: endCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Since",
            NamePos: endCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Args: []ast.Expr {ast.NewIdent("startT")},
            Lparen: endCodePos,
            Rparen: endCodePos,
        }
        assignStmt := &ast.AssignStmt{
            Lhs:    []ast.Expr {ast.NewIdent("tc")},
            TokPos: endCodePos,
            Tok: token.DEFINE,
            Rhs: []ast.Expr { callExpr},
        }
        bodyStmt.List = append(listStmt, assignStmt)
    }
}

func addEndPrintCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        endCodePos := parseFuncEndPos(bodyStmt)
        xValue := &ast.Ident{
            Name: "fmt",
            NamePos: endCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Printf",
            NamePos: endCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        basicList := &ast.BasicLit{
            ValuePos: endCodePos,
            Kind:     token.STRING,
            Value:    "\"time cost = %v\\n\"",
        }
        nameValue := &ast.Ident{
            Name: "tc",
            NamePos: endCodePos,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Args: []ast.Expr {basicList, nameValue},
            Lparen: endCodePos,
            Rparen: endCodePos,
        }
        exprStmt := &ast.ExprStmt{callExpr}
        bodyStmt.List = append(listStmt, exprStmt)
    }
}

func (v *Visitor) Visit(node ast.Node) ast.Visitor {
    switch node.(type) {
    case *ast.GenDecl:
        genDecl := node.(*ast.GenDecl)
        // 查找有没有import time包
        if genDecl.Tok == token.IMPORT {
            v.addImport(genDecl, "time")
            v.addImport(genDecl, "fmt")
            // 不需要再遍历子树
            return nil
        }
    case *ast.FuncDecl:
        funcDecl := node.(*ast.FuncDecl)
        addStartTimeCode(funcDecl)
        addEndTimeCode(funcDecl)
        addEndPrintCode(funcDecl)
        fmt.Println(funcDecl)
    }
    return v
}

// addImport 引入time包
func (v *Visitor) addImport(genDecl *ast.GenDecl, importName string) {
    // 是否已经import
    hasImported := false
    for _, v := range genDecl.Specs {
        imptSpec := v.(*ast.ImportSpec)
        // 如果已经包含"context"
        if imptSpec.Path.Value == strconv.Quote(importName) {
            hasImported = true
        }
    }
    // 如果没有import time,则import
    if !hasImported {
        genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
            Path: &ast.BasicLit{
                Kind:  token.STRING,
                Value: strconv.Quote(importName),
            },
        })
    }
}

func main() {
    fset := token.NewFileSet()
    path, _ := filepath.Abs("./demo.go")
    f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
    if err != nil {
        log.Println(err)
        return
    }
    // 遍历节点
    v := &Visitor{}
    ast.Walk(v, f)
    var output []byte
    buffer := bytes.NewBuffer(output)
    err = format.Node(buffer, fset, f)
    if err != nil {
        log.Fatal(err)
    }
    // 输出Go代码 到对应文件
    fmt.Println(buffer.String())
    pathNew, _ := filepath.Abs("./demo_new.go")
    file, err := os.OpenFile(pathNew, os.O_WRONLY|os.O_CREATE, 0666)
    defer file.Close()
    write := bufio.NewWriter(file)
    write.WriteString(buffer.String())
    write.Flush()
}

实现的效果满足需求,具体效果如下(仅对 greet 方法插入了耗时统计的代码):

package main

import (
    "fmt"
    "time"
)

// main方法
func greet() {
    startT := time.Now()
    // test
    total := 0
    for i := 1; i <= 1000; i++ {
        total += i
    }
    fmt.Printf("total = %v\n", total)
    tc := time.Since(startT)
    fmt.Printf("time cost = %v\n", tc)
}

func test() {}

当 Demo 完成后第一感觉就是 NiuBility, 一般这种情况不出意外的话马上就要出意外了😂。对真实的项目脚本进行实操时直接各种问题。真是:一顿操作猛如虎,实操却是二百五。。。
图3


问题汇总

  • 在 demo 中上来就注入了 import time 包, 但通过对文件扫描后发现没有能注入的,此时引入的 time 包未被使用导致编译失败;
  • time cost = ... 耗时的代码会在最后一行之后注入, 但实际代码最后一般都是 return ,结果可想而知;
  • 方法中存在多个 return, 当前只对最后一个进行了注入;
  • 有的文件中都没有 import 代码块,需要整个注入。
  • ...

通过 2 天坚持不懈的打补丁和兼容, 完整版总算完成了。。。(此处就先不贴具体代码,后续实用一段时间稳定后再贴出来

图4


思考

  • 效率提升:通过 AST、go-callvis 等工具获取整个脚本调用链路, 再针对的对调用链路上的每个方法进行注入,而不是笼统的对整个目录或者单个文件的所有方法进行耗时统计代码注入。
  • 易用性提升:开发一个 Web 界面, 输入脚本 gitlab 的代码仓库地址以及脚本的相对路径则可完成注入。而不是仅支持个人本地执行;

最后感慨下:真的羡慕那些过目不忘的人(自己这昨天晚上吃了什么第二天早上可能就忘记了), 好久没有看 AST 了,导致此次使用上也生疏很多😂。。。

精通 AST 这真的是一个比较强大的技能, 当前能想到的场景包括调用链路分析、代码注入、变异测试中的代码篡改等等,直接对类似 sonar 扫描中的原理看有点上手太难。

!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!
!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!
!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!

共收到 3 条回复 时间 点赞

一些基础的 ast 教程(不同语言差异不大,比如:go、js、java 等)

java 的话我一般选择 arthas 的 trace 命令

小狄子 回复

看了下相关教程 https://arthas.aliyun.com/doc/trace.html 可以满足链路的调用, 有 java 项目时我深入使用看看。

需要 登录 后方可回复, 如果你还没有账号请点击这里 注册