提交 57b3897e authored 作者: YunaiV's avatar YunaiV

1. 修复 data-permission 单元测试的报错

上级 8c4332f4
......@@ -18,7 +18,6 @@ import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
......@@ -37,7 +36,7 @@ import java.util.concurrent.ConcurrentHashMap;
/**
* 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
* 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, Table)} 方法
* 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
*
* 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
* 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
......@@ -53,8 +52,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter,
RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
......@@ -68,12 +66,14 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
// 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
@Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景
@Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
......@@ -92,7 +92,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
// 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
......@@ -107,24 +109,6 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
}
}
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodys = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodys)) {
selectBodys.forEach(this::processSelectBody);
}
}
}
/**
* update 语句处理
*/
......@@ -142,28 +126,77 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
}
// ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(this::processSelectBody);
}
}
}
/**
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
plainSelect.setWhere(builderExpression(where, fromTable));
} else {
processFromItem(fromItem);
}
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
processJoins(joins);
mainTables = processJoins(mainTables, joins);
}
// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}
private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}
/**
......@@ -191,7 +224,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
return;
}
if (where instanceof FromItem) {
processFromItem((FromItem) where);
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
......@@ -204,9 +237,9 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
} else if (where instanceof InExpression) {
// in
InExpression expression = (InExpression) where;
ItemsList itemsList = expression.getRightItemsList();
if (itemsList instanceof SubSelect) {
processSelectBody(((SubSelect) itemsList).getSelectBody());
Expression inExpression = expression.getRightExpression();
if (inExpression instanceof SubSelect) {
processSelectBody(((SubSelect) inExpression).getSelectBody());
}
} else if (where instanceof ExistsExpression) {
// exists
......@@ -239,7 +272,7 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p>
*
* @param function 函数
* @param function
*/
protected void processFunction(Function function) {
ExpressionList parameters = function.getParameters();
......@@ -257,22 +290,19 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
/**
* 处理子查询等
*/
protected void processFromItem(FromItem fromItem) {
if (fromItem instanceof SubJoin) {
SubJoin subJoin = (SubJoin) fromItem;
if (subJoin.getJoinList() != null) {
processJoins(subJoin.getJoinList());
}
if (subJoin.getLeft() != null) {
processFromItem(subJoin.getLeft());
}
} else if (fromItem instanceof SubSelect) {
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
} else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subquery, if you do not give us feedback");
logger.debug("Perform a subQuery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) {
......@@ -284,75 +314,176 @@ public class DataPermissionDatabaseInterceptor extends JsqlParserSupport impleme
}
}
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}
/**
* 处理 joins
*
* @param joins join 集合
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private void processJoins(List<Join> joins) {
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables == null) {
mainTables = new ArrayList<>();
} else if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
Deque<Table> tables = new LinkedList<>();
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem fromItem = join.getRightItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
FromItem joinItem = join.getRightItem();
// 获取当前 join 的表,subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
onTables = Collections.singletonList(joinTable);
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1) {
processJoin(join);
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
tables.push(fromTable);
// 表名压栈,忽略的表压入 null,以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
Table currentTable = tables.poll();
onExpressions.add(builderExpression(originOnExpression, currentTable));
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
// 处理右边连接的子表达式
processFromItem(fromItem);
processOtherFromItem(joinItem);
leftTable = null;
}
}
return mainTables;
}
// ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
/**
* 处理联接语句
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param table 单个表
*/
protected void processJoin(Join join) {
if (join.getRightItem() instanceof Table) {
Table fromTable = (Table) join.getRightItem();
Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions());
originOnExpression = builderExpression(originOnExpression, fromTable);
join.setOnExpressions(CollUtil.newArrayList(originOnExpression));
}
protected Expression builderExpression(Expression currentExpression, Table table) {
return this.builderExpression(currentExpression, Collections.singletonList(table));
}
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param tables 多个表
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
// 获得 Table 对应的数据权限条件
Expression equalsTo = buildDataPermissionExpression(table);
if (equalsTo == null) { // 如果没条件,则返回 currentExpression 默认
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 表达式为空,则直接返回 equalsTo
// 第一步,获得 Table 对应的数据权限条件
Expression dataPermissionExpression = null;
for (Table table : tables) {
// 构建每个表的权限 Expression 条件
Expression expression = buildDataPermissionExpression(table);
if (expression == null) {
continue;
}
// 合并到 dataPermissionExpression 中
dataPermissionExpression = dataPermissionExpression == null ? expression
: new AndExpression(dataPermissionExpression, expression);
}
// 第二步,合并多个 Expression 条件
if (dataPermissionExpression == null) {
return currentExpression;
}
if (currentExpression == null) {
return equalsTo;
return dataPermissionExpression;
}
// 如果表达式为 Or,则需要 (currentExpression) AND equalsTo
// ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), equalsTo);
return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
}
// 如果表达式为 And,则直接返回 currentExpression AND equalsTo
return new AndExpression(currentExpression, equalsTo);
// ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression
return new AndExpression(currentExpression, dataPermissionExpression);
}
/**
......
......@@ -87,7 +87,7 @@ public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
verify(mpBs, times(1)).sql(
eq("SELECT * FROM t_user WHERE id = 1 AND dept_id = 100"));
eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
// 断言缓存
assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
}
......
......@@ -46,7 +46,7 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Override
public Set<String> getTableNames() {
return asSet("entity", "entity1", "entity2", "t1", "t2", // 支持 MyBatis Plus 的单元测试
return asSet("entity", "entity1", "entity2", "entity3", "t1", "t2", "sys_dict_item", // 支持 MyBatis Plus 的单元测试
"t_user", "t_role"); // 满足自己的单元测试
}
......@@ -84,30 +84,30 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Test
void delete() {
assertSql("delete from entity where id = ?",
"DELETE FROM entity WHERE id = ? AND tenant_id = 1");
"DELETE FROM entity WHERE id = ? AND entity.tenant_id = 1");
}
@Test
void update() {
assertSql("update entity set name = ? where id = ?",
"UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1");
"UPDATE entity SET name = ? WHERE id = ? AND entity.tenant_id = 1");
}
@Test
void selectSingle() {
// 单表
assertSql("select * from entity where id = ?",
"SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
"SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1");
assertSql("select * from entity where id = ? or name = ?",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
/* not */
assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1");
}
@Test
......@@ -167,10 +167,12 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <= */
assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <> */
assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
......@@ -204,6 +206,14 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
}
@Test
......@@ -212,17 +222,125 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM with_as_1 e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM with_as_1 e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id ",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
}
@Test
void selectMixJoin() {
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"inner join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
}
@Test
void selectJoinSubSelect() {
assertSql("select * from (select * from entity) e1 " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
assertSql("select * from entity1 e1 " +
"left join (select * from entity2) e2 " +
"on e1.id = e2.id",
"SELECT * FROM entity1 e1 " +
"LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " +
"ON e1.id = e2.id " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectSubJoin() {
assertSql("select * FROM " +
"(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"WHERE e2.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"WHERE e1.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
"right join entity3 e3 on e1.id = e3.id",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
"WHERE e3.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"ON e.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectLeftJoinMultipleTrailingOn() {
// 多个 on 尾缀的
......@@ -256,51 +374,97 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE e.id = ? OR e.name = ?");
assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// 隐式内连接
assertSql("SELECT * FROM entity,entity1 " +
"WHERE entity.id = entity1.id",
"SELECT * FROM entity, entity1 " +
"WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// 隐式内连接
assertSql("SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id AND a.tenant_id = 1");
assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id");
// SubJoin with 隐式内连接
assertSql("SELECT * FROM (entity,entity1) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (entity, entity1) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
assertSql("SELECT * FROM ((entity,entity1),entity2) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM ((entity, entity1), entity2) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
assertSql("SELECT * FROM (entity,(entity1,entity2)) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM (entity, (entity1, entity2)) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
// 沙雕的括号写法
assertSql("SELECT * FROM (((entity,entity1))) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (((entity, entity1))) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// 垃圾 inner join todo
// assertSql("SELECT * FROM entity,entity1 " +
// "WHERE entity.id = entity1.id",
// "SELECT * FROM entity e " +
// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
}
@Test
void selectWithAs() {
assertSql("with with_as_A as (select * from entity) select * from with_as_A",
"WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A");
"WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A");
}
@Test
void selectIgnoreTable() {
assertSql(" SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)",
"SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id AND item.tenant_id = 1 WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)");
}
private void assertSql(String sql, String targetSql) {
assertEquals(targetSql, interceptor.parserSingle(sql, null));
}
// ========== 额外的测试 ==========
@Test
public void testSelectSingle() {
// 单表
assertSql("select * from t_user where id = ?",
"SELECT * FROM t_user WHERE id = ? AND tenant_id = 1 AND dept_id IN (10, 20)");
"SELECT * FROM t_user WHERE id = ? AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("select * from t_user where id = ? or name = ?",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)");
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("SELECT * FROM t_user WHERE (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)");
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
/* not */
assertSql("SELECT * FROM t_user WHERE not (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND tenant_id = 1 AND dept_id IN (10, 20)");
"SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
}
@Test
......@@ -329,16 +493,16 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"right join t_role e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
// 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " +
"right join t_role e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
}
@Test
......@@ -348,23 +512,22 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE e.id = ? OR e.name = ?");
// 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " +
"inner join t_role e1 on e1.id = e.id " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " +
"INNER JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
// 垃圾 inner join todo
// assertSql("SELECT * FROM entity,entity1 " +
// "WHERE entity.id = entity1.id",
// "SELECT * FROM entity e " +
// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// 没有 On 的 inner join
assertSql("SELECT * FROM entity,entity1 " +
"WHERE entity.id = entity1.id",
"SELECT * FROM entity, entity1 " +
"WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
}
}
......@@ -3,13 +3,12 @@ package cn.iocoder.yudao.framework.datapermission.core.rule.dept;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.module.system.api.permission.PermissionApi;
import cn.iocoder.yudao.module.system.api.permission.dto.DeptDataPermissionRespDTO;
import cn.iocoder.yudao.framework.security.core.LoginUser;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import cn.iocoder.yudao.module.system.api.permission.PermissionApi;
import cn.iocoder.yudao.module.system.api.permission.dto.DeptDataPermissionRespDTO;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import org.junit.jupiter.api.BeforeEach;
......@@ -25,6 +24,7 @@ import static cn.iocoder.yudao.framework.datapermission.core.rule.dept.DeptDataP
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomPojo;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomString;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
......@@ -75,6 +75,8 @@ class DeptDataPermissionRuleTest extends BaseMockitoUnitTest {
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法(permissionApi 返回 null)
when(permissionApi.getDeptDataPermission(eq(loginUser.getId()))).thenReturn(success(null));
// 调用
NullPointerException exception = assertThrows(NullPointerException.class,
......
package cn.iocoder.yudao.framework.pay.core.client.dto;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelRefundRespEnum;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
......
......@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPcPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayQrPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayWapPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXLitePayClient;
......@@ -69,7 +70,7 @@ public class PayClientFactoryImpl implements PayClientFactory {
case ALIPAY_WAP: return (AbstractPayClient<Config>) new AlipayWapPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_QR: return (AbstractPayClient<Config>) new AlipayQrPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_APP: return (AbstractPayClient<Config>) new AlipayQrPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_PC: return (AbstractPayClient<Config>) new AlipayQrPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_PC: return (AbstractPayClient<Config>) new AlipayPcPayClient(channelId, (AlipayPayClientConfig) config);
}
// 创建失败,错误日志 + 抛出异常
log.error("[createPayClient][配置({}) 找不到合适的客户端实现]", config);
......
package cn.iocoder.yudao.framework.pay.core.client.impl.alipay;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.http.HttpUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
......@@ -61,7 +60,7 @@ public abstract class AbstractAlipayClient extends AbstractPayClient<AlipayPayCl
return PayOrderNotifyRespDTO.builder().orderExtensionNo(params.get("out_trade_no"))
.channelOrderNo(params.get("trade_no")).channelUserId(params.get("seller_id"))
.tradeStatus(params.get("trade_status"))
.successTime(DateUtil.parse(params.get("notify_time"), "yyyy-MM-dd HH:mm:ss"))
.successTime(LocalDateTimeUtil.parse(params.get("notify_time"), "yyyy-MM-dd HH:mm:ss"))
.data(data.getBody()).build();
}
......@@ -72,7 +71,7 @@ public abstract class AbstractAlipayClient extends AbstractPayClient<AlipayPayCl
.tradeNo(params.get("out_trade_no"))
.reqNo(params.get("out_biz_no"))
.status(PayNotifyRefundStatusEnum.SUCCESS)
.refundSuccessTime(DateUtil.parse(params.get("gmt_refund"), "yyyy-MM-dd HH:mm:ss"))
.refundSuccessTime(LocalDateTimeUtil.parse(params.get("gmt_refund"), "yyyy-MM-dd HH:mm:ss"))
.build();
return notifyDTO;
}
......
package pay.core.client.impl.alipay;
package cn.iocoder.yudao.framework.pay.core.client.impl.alipay;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayOrderUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AbstractAlipayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.alibaba.fastjson.JSONObject;
import com.alipay.api.AlipayApiException;
......
......@@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
......@@ -49,7 +50,7 @@ public class WXLitePayClient extends AbstractPayClient<WXPayClientConfig> {
@Override
protected void doInit() {
WxPayConfig payConfig = new WxPayConfig();
BeanUtil.copyProperties(config, payConfig, "privateKeyContent", "privateCertContent");
BeanUtil.copyProperties(config, payConfig, "keyContent");
payConfig.setTradeType(WxPayConstants.TradeType.JSAPI); // 设置使用 JS API 支付方式
// if (StrUtil.isNotEmpty(config.getKeyContent())) {
// payConfig.setKeyContent(config.getKeyContent().getBytes(StandardCharsets.UTF_8));
......@@ -167,7 +168,7 @@ public class WXLitePayClient extends AbstractPayClient<WXPayClientConfig> {
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(DateUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
......@@ -181,7 +182,7 @@ public class WXLitePayClient extends AbstractPayClient<WXPayClientConfig> {
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(DateUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
......
......@@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.io.FileUtils;
......@@ -150,7 +151,7 @@ public class WXNativePayClient extends AbstractPayClient<WXPayClientConfig> {
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(DateUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
......@@ -164,7 +165,7 @@ public class WXNativePayClient extends AbstractPayClient<WXPayClientConfig> {
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(DateUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
......
......@@ -2,6 +2,7 @@ package cn.iocoder.yudao.framework.pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
......@@ -161,7 +162,7 @@ public class WXPubPayClient extends AbstractPayClient<WXPayClientConfig> {
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(DateUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
......@@ -175,7 +176,7 @@ public class WXPubPayClient extends AbstractPayClient<WXPayClientConfig> {
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(DateUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
......
package pay.config;
import lombok.Data;
import org.hibernate.validator.constraints.URL;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.validation.annotation.Validated;
import javax.validation.constraints.NotEmpty;
@ConfigurationProperties(prefix = "yudao.pay")
@Validated
@Data
public class PayProperties {
/**
* 支付回调地址
* 注意,支付渠道统一回调到 payNotifyUrl 地址,由支付模块统一处理;然后,自己的支付模块,在回调 PayAppDO.payNotifyUrl 地址
*/
@NotEmpty(message = "支付回调地址不能为空")
@URL(message = "支付回调地址的格式必须是 URL")
private String payNotifyUrl;
/**
* 退款回调地址
* 注意点,同 {@link #payNotifyUrl} 属性
*/
@NotEmpty(message = "退款回调地址不能为空")
@URL(message = "退款回调地址的格式必须是 URL")
private String refundNotifyUrl;
/**
* 支付完成的返回地址
*/
@URL(message = "支付返回的地址的格式必须是 URL")
@NotEmpty(message = "支付返回的地址不能为空")
private String payReturnUrl;
}
package pay.config;
import cn.iocoder.yudao.framework.pay.config.PayProperties;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.impl.PayClientFactoryImpl;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* 支付配置类
*
* @author 芋道源码
*/
@Configuration
@EnableConfigurationProperties(PayProperties.class)
public class YudaoPayAutoConfiguration {
@Bean
public PayClientFactory payClientFactory() {
return new PayClientFactoryImpl();
}
}
package pay.core.client;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.enums.PayFrameworkErrorCodeConstants;
import lombok.extern.slf4j.Slf4j;
/**
* 将 API 的错误码,转换为通用的错误码
*
* @see PayCommonResult
* @see PayFrameworkErrorCodeConstants
*
* @author 芋道源码
*/
@Slf4j
public abstract class AbstractPayCodeMapping {
public final ErrorCode apply(String apiCode, String apiMsg) {
if (apiCode == null) {
log.error("[apply][API 错误码为空,请排查]");
return PayFrameworkErrorCodeConstants.EXCEPTION;
}
ErrorCode errorCode = this.apply0(apiCode, apiMsg);
if (errorCode == null) {
log.error("[apply][API 错误码({}) 错误提示({}) 无法匹配]", apiCode, apiMsg);
return PayFrameworkErrorCodeConstants.PAY_UNKNOWN;
}
return errorCode;
}
protected abstract ErrorCode apply0(String apiCode, String apiMsg);
}
package pay.core.client;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
/**
* 支付客户端,用于对接各支付渠道的 SDK,实现发起支付、退款等功能
*
* @author 芋道源码
*/
public interface PayClient {
/**
* 获得渠道编号
*
* @return 渠道编号
*/
Long getId();
/**
* 调用支付渠道,统一下单
*
* @param reqDTO 下单信息
* @return 各支付渠道的返回结果
*/
cn.iocoder.yudao.framework.pay.core.client.PayCommonResult<?> unifiedOrder(PayOrderUnifiedReqDTO reqDTO);
/**
* 解析支付单的通知结果
*
* @param data 通知结果
* @return 解析结果
* @throws Exception 解析失败,抛出异常
*/
PayOrderNotifyRespDTO parseOrderNotify(PayNotifyDataDTO data) throws Exception;
/**
* 调用支付渠道,进行退款
* @param reqDTO 统一退款请求信息
* @return 各支付渠道的统一返回结果
*/
PayCommonResult<PayRefundUnifiedRespDTO> unifiedRefund(PayRefundUnifiedReqDTO reqDTO);
/**
* 解析支付退款通知数据
* @param notifyData 支付退款通知请求数据
* @return 支付退款通知的Notify DTO
*/
PayRefundNotifyDTO parseRefundNotify(PayNotifyDataDTO notifyData);
// TODO @芋艿:后续改成非 default,避免不知道去实现
/**
* 验证是否渠道通知
*
* @param notifyData 通知数据
* @return 默认是 true
*/
default boolean verifyNotifyData(PayNotifyDataDTO notifyData) {
return true;
}
// TODO @芋艿:后续改成非 default,避免不知道去实现
/**
* 判断是否为退款通知
*
* @param notifyData 通知数据
* @return 默认是 false
*/
default boolean isRefundNotify(PayNotifyDataDTO notifyData){
return false;
}
}
package pay.core.client;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import javax.validation.ConstraintViolation;
import javax.validation.ConstraintViolationException;
import javax.validation.Validator;
import java.util.Set;
/**
* 支付客户端的配置,本质是支付渠道的配置
* 每个不同的渠道,需要不同的配置,通过子类来定义
*
* @author 芋道源码
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
// @JsonTypeInfo 注解的作用,Jackson 多态
// 1. 序列化到时数据库时,增加 @class 属性。
// 2. 反序列化到内存对象时,通过 @class 属性,可以创建出正确的类型
public interface PayClientConfig {
/**
* 配置验证参数是
*
* @param validator 校验对象
* @return 配置好的验证参数
*/
Set<ConstraintViolation<PayClientConfig>> verifyParam(Validator validator);
// TODO @aquan:貌似抽象一个 validation group 就好了!
/**
* 参数校验
*
* @param validator 校验对象
*/
default void validate(Validator validator) {
Set<ConstraintViolation<PayClientConfig>> violations = verifyParam(validator);
if (!violations.isEmpty()) {
throw new ConstraintViolationException(violations);
}
}
}
package pay.core.client;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
/**
* 支付客户端的工厂接口
*
* @author 芋道源码
*/
public interface PayClientFactory {
/**
* 获得支付客户端
*
* @param channelId 渠道编号
* @return 支付客户端
*/
PayClient getPayClient(Long channelId);
/**
* 创建支付客户端
*
* @param channelId 渠道编号
* @param channelCode 渠道编码
* @param config 支付配置
*/
<Config extends PayClientConfig> void createOrUpdatePayClient(Long channelId, String channelCode,
Config config);
}
package pay.core.client;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.lang.Assert;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.enums.PayFrameworkErrorCodeConstants;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
/**
* 支付的 CommonResult 拓展类
*
* 考虑到不同的平台,返回的 code 和 msg 是不同的,所以统一额外返回 {@link #apiCode} 和 {@link #apiMsg} 字段
*
* @author 芋道源码
*/
@Data
@EqualsAndHashCode(callSuper = true)
@ToString(callSuper = true)
public class PayCommonResult<T> extends CommonResult<T> {
/**
* API 返回错误码
*
* 由于第三方的错误码可能是字符串,所以使用 String 类型
*/
private String apiCode;
/**
* API 返回提示
*/
private String apiMsg;
private PayCommonResult() {
}
public static <T> PayCommonResult<T> build(String apiCode, String apiMsg, T data, AbstractPayCodeMapping codeMapping) {
Assert.notNull(codeMapping, "参数 codeMapping 不能为空");
PayCommonResult<T> result = new PayCommonResult<T>().setApiCode(apiCode).setApiMsg(apiMsg);
result.setData(data);
// 翻译错误码
if (codeMapping != null) {
ErrorCode errorCode = codeMapping.apply(apiCode, apiMsg);
result.setCode(errorCode.getCode()).setMsg(errorCode.getMsg());
}
return result;
}
public static <T> PayCommonResult<T> error(Throwable ex) {
PayCommonResult<T> result = new PayCommonResult<>();
result.setCode(PayFrameworkErrorCodeConstants.EXCEPTION.getCode());
result.setMsg(ExceptionUtil.getRootCauseMessage(ex));
return result;
}
}
package pay.core.client.dto;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import java.util.Map;
/**
* 支付订单,退款订单回调,渠道的统一通知请求数据
*/
@Data
@ToString
@Builder
public class PayNotifyDataDTO {
/**
* HTTP 回调接口的 request body
*/
private String body;
/**
* HTTP 回调接口 content type 为 application/x-www-form-urlencoded 的所有参数
*/
private Map<String,String> params;
}
package pay.core.client.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
/**
* 支付通知 Response DTO
*
* @author 芋道源码
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class PayOrderNotifyRespDTO {
/**
* 支付订单号(支付模块的)
*/
private String orderExtensionNo;
/**
* 支付渠道编号
*/
private String channelOrderNo;
/**
* 支付渠道用户编号
*/
private String channelUserId;
/**
* 支付成功时间
*/
private LocalDateTime successTime;
/**
* 通知的原始数据
*
* 主要用于持久化,方便后续修复数据,或者排错
*/
private String data;
/**
* TODO @jason 结合其他的渠道定义成枚举,
* alipay
* TRADE_CLOSED,未付款交易超时关闭,或支付完成后全额退款。
* TRADE_SUCCESS, 交易支付成功
* TRADE_FINISHED 交易结束,不可退款。
*/
private String tradeStatus;
}
package pay.core.client.dto;
import lombok.Data;
import org.hibernate.validator.constraints.Length;
import org.hibernate.validator.constraints.URL;
import javax.validation.constraints.DecimalMin;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
import java.time.LocalDateTime;
import java.util.Map;
/**
* 统一下单 Request DTO
*
* @author 芋道源码
*/
@Data
public class PayOrderUnifiedReqDTO {
/**
* 用户 IP
*/
@NotEmpty(message = "用户 IP 不能为空")
private String userIp;
// ========== 商户相关字段 ==========
/**
* 商户订单编号
*/
@NotEmpty(message = "商户订单编号不能为空")
private String merchantOrderId;
/**
* 商品标题
*/
@NotEmpty(message = "商品标题不能为空")
@Length(max = 32, message = "商品标题不能超过 32")
private String subject;
/**
* 商品描述信息
*/
@NotEmpty(message = "商品描述信息不能为空")
@Length(max = 128, message = "商品描述信息长度不能超过128")
private String body;
/**
* 支付结果的 notify 回调地址
*/
@NotEmpty(message = "支付结果的回调地址不能为空")
@URL(message = "支付结果的 notify 回调地址必须是 URL 格式")
private String notifyUrl;
/**
* 支付结果的 return 回调地址
*/
@URL(message = "支付结果的 return 回调地址必须是 URL 格式")
private String returnUrl;
// ========== 订单相关字段 ==========
/**
* 支付金额,单位:分
*/
@NotNull(message = "支付金额不能为空")
@DecimalMin(value = "0", inclusive = false, message = "支付金额必须大于零")
private Long amount;
/**
* 支付过期时间
*/
@NotNull(message = "支付过期时间不能为空")
private LocalDateTime expireTime;
// ========== 拓展参数 ==========
/**
* 支付渠道的额外参数
*
* 例如说,微信公众号需要传递 openid 参数
*/
private Map<String, String> channelExtras;
}
package pay.core.client.dto;
import cn.iocoder.yudao.framework.pay.core.enums.PayNotifyRefundStatusEnum;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import java.time.LocalDateTime;
/**
* 从渠道返回数据中解析得到的支付退款通知的Notify DTO
*
* @author jason
*/
@Data
@ToString
@Builder
public class PayRefundNotifyDTO {
/**
* 支付渠道编号
*/
private String channelOrderNo;
/**
* 交易订单号,根据规则生成
* 调用支付渠道时,使用该字段作为对接的订单号。
* 1. 调用微信支付 https://api.mch.weixin.qq.com/pay/unifiedorder 时,使用该字段作为 out_trade_no
* 2. 调用支付宝 https://opendocs.alipay.com/apis 时,使用该字段作为 out_trade_no
* 这里对应 pay_extension 里面的 no
* 例如说,P202110132239124200055
*/
private String tradeNo;
/**
* https://api.mch.weixin.qq.com/v3/refund/domestic/refunds 中的 out_refund_no
* https://opendocs.alipay.com/apis alipay.trade.refund 中的 out_request_no
* 退款请求号。
* 标识一次退款请求,需要保证在交易号下唯一,如需部分退款,则此参数必传。
* 注:针对同一次退款请求,如果调用接口失败或异常了,重试时需要保证退款请求号不能变更,
* 防止该笔交易重复退款。支付宝会保证同样的退款请求号多次请求只会退一次。
* 退款单请求号,根据规则生成
*
* 例如说,RR202109181134287570000
*/
private String reqNo;
/**
* 退款是否成功
*/
private PayNotifyRefundStatusEnum status;
/**
* 退款成功时间
*/
private LocalDateTime refundSuccessTime;
}
package pay.core.client.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import org.hibernate.validator.constraints.URL;
import javax.validation.constraints.DecimalMin;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
/**
* 统一 退款 Request DTO
*
* @author jason
*/
@Accessors(chain = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class PayRefundUnifiedReqDTO {
/**
* 用户 IP
*/
private String userIp;
// TODO @jason:这个是否为非必传字段呀,只需要传递 payTradeNo 字段即可。尽可能精简
/**
* https://api.mch.weixin.qq.com/v3/refund/domestic/refunds 中的 transaction_id
* https://opendocs.alipay.com/apis alipay.trade.refund 中的 trade_no
* 渠道订单号
*/
private String channelOrderNo;
/**
* https://api.mch.weixin.qq.com/v3/refund/domestic/refunds 中的 out_trade_no
* https://opendocs.alipay.com/apis alipay.trade.refund 中的 out_trade_no
* 支付交易号 {PayOrderExtensionDO no字段} 和 渠道订单号 不能同时为空
*/
private String payTradeNo;
/**
* https://api.mch.weixin.qq.com/v3/refund/domestic/refunds 中的 out_refund_no
* https://opendocs.alipay.com/apis alipay.trade.refund 中的 out_trade_no
* 退款请求单号 同一退款请求单号多次请求只退一笔。
* 使用 商户的退款单号。{PayRefundDO 字段 merchantRefundNo}
*/
@NotEmpty(message = "退款请求单号")
private String merchantRefundId;
/**
* 退款原因
*/
@NotEmpty(message = "退款原因不能为空")
private String reason;
/**
* 退款金额,单位:分
*/
@NotNull(message = "退款金额不能为空")
@DecimalMin(value = "0", inclusive = false, message = "支付金额必须大于零")
private Long amount;
/**
* 退款结果 notify 回调地址, 支付宝退款不需要回调地址, 微信需要
*/
@URL(message = "支付结果的 notify 回调地址必须是 URL 格式")
private String notifyUrl;
}
package pay.core.client.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
/**
* 统一退款 Response DTO
*
* @author jason
*/
@Accessors(chain = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class PayRefundUnifiedRespDTO {
/**
* 渠道退款单编号
*/
private String channelRefundId;
}
package pay.core.client.impl;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayOrderUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayRefundUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayRefundUnifiedRespDTO;
import lombok.extern.slf4j.Slf4j;
import javax.validation.Validation;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
/**
* 支付客户端的抽象类,提供模板方法,减少子类的冗余代码
*
* @author 芋道源码
*/
@Slf4j
public abstract class AbstractPayClient<Config extends PayClientConfig> implements PayClient {
/**
* 渠道编号
*/
private final Long channelId;
/**
* 渠道编码
*/
private final String channelCode;
/**
* 错误码枚举类
*/
protected AbstractPayCodeMapping codeMapping;
/**
* 支付配置
*/
protected Config config;
public AbstractPayClient(Long channelId, String channelCode, Config config, AbstractPayCodeMapping codeMapping) {
this.channelId = channelId;
this.channelCode = channelCode;
this.codeMapping = codeMapping;
this.config = config;
}
/**
* 初始化
*/
public final void init() {
doInit();
log.info("[init][配置({}) 初始化完成]", config);
}
/**
* 自定义初始化
*/
protected abstract void doInit();
public final void refresh(Config config) {
// 判断是否更新
if (config.equals(this.config)) {
return;
}
log.info("[refresh][配置({})发生变化,重新初始化]", config);
this.config = config;
// 初始化
this.init();
}
protected Double calculateAmount(Long amount) {
return amount / 100.0;
}
@Override
public Long getId() {
return channelId;
}
@Override
public final PayCommonResult<?> unifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
Validation.buildDefaultValidatorFactory().getValidator().validate(reqDTO);
// 执行短信发送
PayCommonResult<?> result;
try {
result = doUnifiedOrder(reqDTO);
} catch (Throwable ex) {
// 打印异常日志
log.error("[unifiedOrder][request({}) 发起支付失败]", toJsonString(reqDTO), ex);
// 封装返回
return PayCommonResult.error(ex);
}
return result;
}
protected abstract PayCommonResult<?> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO)
throws Throwable;
@Override
public PayCommonResult<PayRefundUnifiedRespDTO> unifiedRefund(PayRefundUnifiedReqDTO reqDTO) {
PayCommonResult<PayRefundUnifiedRespDTO> resp;
try {
resp = doUnifiedRefund(reqDTO);
} catch (Throwable ex) {
// 记录异常日志
log.error("[unifiedRefund][request({}) 发起退款失败]", toJsonString(reqDTO), ex);
resp = PayCommonResult.error(ex);
}
return resp;
}
protected abstract PayCommonResult<PayRefundUnifiedRespDTO> doUnifiedRefund(PayRefundUnifiedReqDTO reqDTO) throws Throwable;
}
package pay.core.client.impl;
import cn.hutool.core.lang.Assert;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPcPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayQrPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayWapPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXLitePayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXNativePayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPubPayClient;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* 支付客户端的工厂实现类
*
* @author 芋道源码
*/
@Slf4j
public class PayClientFactoryImpl implements PayClientFactory {
/**
* 支付客户端 Map
* key:渠道编号
*/
private final ConcurrentMap<Long, cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<?>> clients = new ConcurrentHashMap<>();
@Override
public PayClient getPayClient(Long channelId) {
cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<?> client = clients.get(channelId);
if (client == null) {
log.error("[getPayClient][渠道编号({}) 找不到客户端]", channelId);
}
return client;
}
@Override
@SuppressWarnings("unchecked")
public <Config extends PayClientConfig> void createOrUpdatePayClient(Long channelId, String channelCode,
Config config) {
cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config> client = (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) clients.get(channelId);
if (client == null) {
client = this.createPayClient(channelId, channelCode, config);
client.init();
clients.put(client.getId(), client);
} else {
client.refresh(config);
}
}
@SuppressWarnings("unchecked")
private <Config extends PayClientConfig> cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config> createPayClient(
Long channelId, String channelCode, Config config) {
PayChannelEnum channelEnum = PayChannelEnum.getByCode(channelCode);
Assert.notNull(channelEnum, String.format("支付渠道(%s) 为空", channelEnum));
// 创建客户端
// TODO @芋艿 WX_LITE WX_APP 如果不添加在 项目启动的时候去初始化会报错无法启动。所以我手动加了两个,具体需要你来配
switch (channelEnum) {
case WX_PUB: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new WXPubPayClient(channelId, (WXPayClientConfig) config);
case WX_LITE: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new WXLitePayClient(channelId, (WXPayClientConfig) config); //微信小程序请求支付
case WX_APP: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new WXPubPayClient(channelId, (WXPayClientConfig) config);
case WX_NATIVE: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new WXNativePayClient(channelId, (WXPayClientConfig) config);
case ALIPAY_WAP: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new AlipayWapPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_QR: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new AlipayQrPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_APP: return (cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient<Config>) new AlipayQrPayClient(channelId, (AlipayPayClientConfig) config);
case ALIPAY_PC: return (AbstractPayClient<Config>) new AlipayPcPayClient(channelId, (AlipayPayClientConfig) config);
}
// 创建失败,错误日志 + 抛出异常
log.error("[createPayClient][配置({}) 找不到合适的客户端实现]", config);
throw new IllegalArgumentException(String.format("配置(%s) 找不到合适的客户端实现", config));
}
}
package pay.core.client.impl.alipay;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
import cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.enums.PayNotifyRefundStatusEnum;
import com.alipay.api.AlipayApiException;
import com.alipay.api.AlipayConfig;
import com.alipay.api.DefaultAlipayClient;
import com.alipay.api.domain.AlipayTradeRefundModel;
import com.alipay.api.internal.util.AlipaySignature;
import com.alipay.api.request.AlipayTradeRefundRequest;
import com.alipay.api.response.AlipayTradeRefundResponse;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
/**
* 支付宝抽象类, 实现支付宝统一的接口。如退款
*
* @author jason
*/
@Slf4j
public abstract class AbstractAlipayClient extends AbstractPayClient<cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig> {
protected DefaultAlipayClient client;
public AbstractAlipayClient(Long channelId, String channelCode,
AlipayPayClientConfig config, AbstractPayCodeMapping codeMapping) {
super(channelId, channelCode, config, codeMapping);
}
@Override
@SneakyThrows
protected void doInit() {
AlipayConfig alipayConfig = new AlipayConfig();
BeanUtil.copyProperties(config, alipayConfig, false);
this.client = new DefaultAlipayClient(alipayConfig);
}
/**
* 从支付宝通知返回参数中解析 PayOrderNotifyRespDTO, 通知具体参数参考
* //https://opendocs.alipay.com/open/203/105286
* @param data 通知结果
* @return 解析结果 PayOrderNotifyRespDTO
* @throws Exception 解析失败,抛出异常
*/
@Override
public PayOrderNotifyRespDTO parseOrderNotify(PayNotifyDataDTO data) throws Exception {
Map<String, String> params = strToMap(data.getBody());
return PayOrderNotifyRespDTO.builder().orderExtensionNo(params.get("out_trade_no"))
.channelOrderNo(params.get("trade_no")).channelUserId(params.get("seller_id"))
.tradeStatus(params.get("trade_status"))
.successTime(LocalDateTimeUtil.parse(params.get("notify_time"), "yyyy-MM-dd HH:mm:ss"))
.data(data.getBody()).build();
}
@Override
public PayRefundNotifyDTO parseRefundNotify(PayNotifyDataDTO notifyData) {
Map<String, String> params = strToMap(notifyData.getBody());
PayRefundNotifyDTO notifyDTO = PayRefundNotifyDTO.builder().channelOrderNo(params.get("trade_no"))
.tradeNo(params.get("out_trade_no"))
.reqNo(params.get("out_biz_no"))
.status(PayNotifyRefundStatusEnum.SUCCESS)
.refundSuccessTime(LocalDateTimeUtil.parse(params.get("gmt_refund"), "yyyy-MM-dd HH:mm:ss"))
.build();
return notifyDTO;
}
@Override
public boolean isRefundNotify(PayNotifyDataDTO notifyData) {
if (notifyData.getParams().containsKey("refund_fee")) {
return true;
} else {
return false;
}
}
@Override
public boolean verifyNotifyData(PayNotifyDataDTO notifyData) {
boolean verifyResult = false;
try {
verifyResult = AlipaySignature.rsaCheckV1(notifyData.getParams(), config.getAlipayPublicKey(), StandardCharsets.UTF_8.name(), "RSA2");
} catch (AlipayApiException e) {
log.error("[AlipayClient verifyNotifyData][(notify param is :{}) 验证失败]", toJsonString(notifyData.getParams()), e);
}
return verifyResult;
}
/**
* 支付宝统一的退款接口 alipay.trade.refund
* @param reqDTO 退款请求 request DTO
* @return 退款请求 Response
*/
@Override
protected PayCommonResult<PayRefundUnifiedRespDTO> doUnifiedRefund(PayRefundUnifiedReqDTO reqDTO) {
AlipayTradeRefundModel model=new AlipayTradeRefundModel();
model.setTradeNo(reqDTO.getChannelOrderNo());
model.setOutTradeNo(reqDTO.getPayTradeNo());
model.setOutRequestNo(reqDTO.getMerchantRefundId());
model.setRefundAmount(calculateAmount(reqDTO.getAmount()).toString());
model.setRefundReason(reqDTO.getReason());
AlipayTradeRefundRequest refundRequest = new AlipayTradeRefundRequest();
refundRequest.setBizModel(model);
try {
AlipayTradeRefundResponse response = client.execute(refundRequest);
log.info("[doUnifiedRefund][response({}) 发起退款 渠道返回", toJsonString(response));
if (response.isSuccess()) {
//退款导致触发的异步通知是发送到支付接口中设置的notify_url
//支付宝不返回退款单号,设置为空
PayRefundUnifiedRespDTO respDTO = new PayRefundUnifiedRespDTO();
respDTO.setChannelRefundId("");
return PayCommonResult.build(response.getCode(), response.getMsg(), respDTO, codeMapping);
}
// 失败。需要抛出异常
return PayCommonResult.build(response.getCode(), response.getMsg(), null, codeMapping);
} catch (AlipayApiException e) {
// TODO 记录异常日志
log.error("[doUnifiedRefund][request({}) 发起退款失败,网络读超时,退款状态未知]", toJsonString(reqDTO), e);
return PayCommonResult.build(e.getErrCode(), e.getErrMsg(), null, codeMapping);
}
}
/**
* 支付宝统一回调参数 str 转 map
*
* @param s 支付宝支付通知回调参数
* @return map 支付宝集合
*/
public static Map<String, String> strToMap(String s) {
// TODO @zxy:这个可以使用 hutool 的 HttpUtil decodeParams 方法么?
Map<String, String> stringStringMap = new HashMap<>();
// 调整时间格式
String s3 = s.replaceAll("%3A", ":");
// 获取 map
String s4 = s3.replace("+", " ");
String[] split = s4.split("&");
for (String s1 : split) {
String[] split1 = s1.split("=");
stringStringMap.put(split1[0], split1[1]);
}
return stringStringMap;
}
}
package pay.core.client.impl.alipay;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import lombok.Data;
import javax.validation.ConstraintViolation;
import javax.validation.Validator;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.util.Set;
// TODO 芋艿:参数校验
/**
* 支付宝的 PayClientConfig 实现类
* 属性主要来自 {@link com.alipay.api.AlipayConfig} 的必要属性
*
* @author 芋道源码
*/
@Data
public class AlipayPayClientConfig implements PayClientConfig {
/**
* 网关地址 - 线上
*/
public static final String SERVER_URL_PROD = "https://openapi.alipay.com/gateway.do";
/**
* 网关地址 - 沙箱
*/
public static final String SERVER_URL_SANDBOX = "https://openapi.alipaydev.com/gateway.do";
/**
* 公钥类型 - 公钥模式
*/
public static final Integer MODE_PUBLIC_KEY = 1;
/**
* 公钥类型 - 证书模式
*/
public static final Integer MODE_CERTIFICATE = 2;
/**
* 签名算法类型 - RSA
*/
public static final String SIGN_TYPE_DEFAULT = "RSA2";
/**
* 网关地址
* 1. {@link #SERVER_URL_PROD}
* 2. {@link #SERVER_URL_SANDBOX}
*/
@NotBlank(message = "网关地址不能为空", groups = {ModePublicKey.class, ModeCertificate.class})
private String serverUrl;
/**
* 开放平台上创建的应用的 ID
*/
@NotBlank(message = "开放平台上创建的应用的 ID不能为空", groups = {ModePublicKey.class, ModeCertificate.class})
private String appId;
/**
* 签名算法类型,推荐:RSA2
* <p>
* {@link #SIGN_TYPE_DEFAULT}
*/
@NotBlank(message = "签名算法类型不能为空", groups = {ModePublicKey.class, ModeCertificate.class})
private String signType;
/**
* 公钥类型
* 1. {@link #MODE_PUBLIC_KEY} 情况,privateKey + alipayPublicKey
* 2. {@link #MODE_CERTIFICATE} 情况,appCertContent + alipayPublicCertContent + rootCertContent
*/
@NotNull(message = "公钥类型不能为空", groups = {ModePublicKey.class, ModeCertificate.class})
private Integer mode;
// ========== 公钥模式 ==========
/**
* 商户私钥
*/
@NotBlank(message = "商户私钥不能为空", groups = {ModePublicKey.class})
private String privateKey;
/**
* 支付宝公钥字符串
*/
@NotBlank(message = "支付宝公钥字符串不能为空", groups = {ModePublicKey.class})
private String alipayPublicKey;
// ========== 证书模式 ==========
/**
* 指定商户公钥应用证书内容字符串
*/
@NotBlank(message = "指定商户公钥应用证书内容不能为空", groups = {ModeCertificate.class})
private String appCertContent;
/**
* 指定支付宝公钥证书内容字符串
*/
@NotBlank(message = "指定支付宝公钥证书内容不能为空", groups = {ModeCertificate.class})
private String alipayPublicCertContent;
/**
* 指定根证书内容字符串
*/
@NotBlank(message = "指定根证书内容字符串不能为空", groups = {ModeCertificate.class})
private String rootCertContent;
public interface ModePublicKey {
}
public interface ModeCertificate {
}
@Override
public Set<ConstraintViolation<PayClientConfig>> verifyParam(Validator validator) {
return validator.validate(this,
MODE_PUBLIC_KEY.equals(this.getMode()) ? ModePublicKey.class : ModeCertificate.class);
}
}
package pay.core.client.impl.alipay;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import java.util.Objects;
/**
* 支付宝的 PayCodeMapping 实现类
*
* @author 芋道源码
*/
public class AlipayPayCodeMapping extends AbstractPayCodeMapping {
@Override
protected ErrorCode apply0(String apiCode, String apiMsg) {
if (Objects.equals(apiCode, "10000")) {
return GlobalErrorCodeConstants.SUCCESS;
}
// alipay wap api code 返回为null, 暂时定为-9999
if (Objects.equals(apiCode, "-9999")) {
return GlobalErrorCodeConstants.SUCCESS;
}
return null;
}
}
package pay.core.client.impl.alipay;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayOrderUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AbstractAlipayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.alipay.api.AlipayApiException;
import com.alipay.api.domain.AlipayTradePrecreateModel;
import com.alipay.api.request.AlipayTradePrecreateRequest;
import com.alipay.api.response.AlipayTradePrecreateResponse;
import lombok.extern.slf4j.Slf4j;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
/**
* 支付宝【扫码支付】的 PayClient 实现类
* 文档:https://opendocs.alipay.com/apis/02890k
*
* @author 芋道源码
*/
@Slf4j
public class AlipayQrPayClient extends AbstractAlipayClient {
public AlipayQrPayClient(Long channelId, AlipayPayClientConfig config) {
super(channelId, PayChannelEnum.ALIPAY_QR.getCode(), config, new AlipayPayCodeMapping());
}
@Override
public PayCommonResult<AlipayTradePrecreateResponse> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
// 构建 AlipayTradePrecreateModel 请求
AlipayTradePrecreateModel model = new AlipayTradePrecreateModel();
model.setOutTradeNo(reqDTO.getMerchantOrderId());
model.setSubject(reqDTO.getSubject());
model.setBody(reqDTO.getBody());
model.setTotalAmount(calculateAmount(reqDTO.getAmount()).toString()); // 单位:元
// TODO 芋艿:userIp + expireTime
// 构建 AlipayTradePrecreateRequest
AlipayTradePrecreateRequest request = new AlipayTradePrecreateRequest();
request.setBizModel(model);
request.setNotifyUrl(reqDTO.getNotifyUrl());
request.setReturnUrl(reqDTO.getReturnUrl());
// 执行请求
AlipayTradePrecreateResponse response;
try {
response = client.execute(request);
} catch (AlipayApiException e) {
log.error("[unifiedOrder][request({}) 发起支付失败]", toJsonString(reqDTO), e);
return PayCommonResult.build(e.getErrCode(), e.getErrMsg(), null, codeMapping);
}
// TODO 芋艿:sub Code 需要测试下各种失败的情况
return PayCommonResult.build(response.getCode(), response.getMsg(), response, codeMapping);
}
}
package pay.core.client.impl.alipay;
import cn.hutool.core.date.DateUtil;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.PayOrderUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AbstractAlipayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayCodeMapping;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.alipay.api.AlipayApiException;
import com.alipay.api.domain.AlipayTradeWapPayModel;
import com.alipay.api.request.AlipayTradeWapPayRequest;
import com.alipay.api.response.AlipayTradeWapPayResponse;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
/**
* 支付宝【手机网站】的 PayClient 实现类
* 文档:https://opendocs.alipay.com/apis/api_1/alipay.trade.wap.pay
*
* @author 芋道源码
*/
@Slf4j
public class AlipayWapPayClient extends AbstractAlipayClient {
public AlipayWapPayClient(Long channelId, AlipayPayClientConfig config) {
super(channelId, PayChannelEnum.ALIPAY_WAP.getCode(), config, new AlipayPayCodeMapping());
}
@Override
public PayCommonResult<AlipayTradeWapPayResponse> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
// 构建 AlipayTradeWapPayModel 请求
AlipayTradeWapPayModel model = new AlipayTradeWapPayModel();
model.setOutTradeNo(reqDTO.getMerchantOrderId());
model.setSubject(reqDTO.getSubject());
model.setBody(reqDTO.getBody());
model.setTotalAmount(calculateAmount(reqDTO.getAmount()).toString());
model.setProductCode("QUICK_WAP_PAY"); // TODO 芋艿:这里咋整
//TODO 芋艿:这里咋整 jason @芋艿 可以去掉吧,
// TODO 芋艿 似乎这里不用传sellerId
// https://opendocs.alipay.com/apis/api_1/alipay.trade.wap.pay
//model.setSellerId("2088102147948060");
model.setTimeExpire(DateUtil.format(reqDTO.getExpireTime(),"yyyy-MM-dd HH:mm:ss"));
// TODO 芋艿:userIp
// 构建 AlipayTradeWapPayRequest
AlipayTradeWapPayRequest request = new AlipayTradeWapPayRequest();
request.setBizModel(model);
request.setNotifyUrl(reqDTO.getNotifyUrl());
request.setReturnUrl(reqDTO.getReturnUrl());
// 执行请求
AlipayTradeWapPayResponse response;
try {
response = client.pageExecute(request);
} catch (AlipayApiException e) {
return PayCommonResult.build(e.getErrCode(), e.getErrMsg(), null, codeMapping);
}
// TODO 芋艿:sub Code
if(response.isSuccess() && Objects.isNull(response.getCode()) && Objects.nonNull(response.getBody())){
//成功alipay wap 成功 code 为 null , body 为form 表单
return PayCommonResult.build("-9999", "Success", response, codeMapping);
}else {
return PayCommonResult.build(response.getCode(), response.getMsg(), response, codeMapping);
}
}
}
package pay.core.client.impl.wx;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import cn.iocoder.yudao.framework.pay.core.client.AbstractPayCodeMapping;
import java.util.Objects;
import static cn.iocoder.yudao.framework.pay.core.enums.PayFrameworkErrorCodeConstants.*;
/**
* 微信支付 PayCodeMapping 实现类
*
* @author 芋道源码
*/
public class WXCodeMapping extends AbstractPayCodeMapping {
/**
* 错误码 - 成功
* 由于 weixin-java-pay 封装的 Result 未返回 code,所以自己定义下
*/
public static final String CODE_SUCCESS = "SUCCESS";
/**
* 错误提示 - 成功
*/
public static final String MESSAGE_SUCCESS = "成功";
@Override
protected ErrorCode apply0(String apiCode, String apiMsg) {
if (Objects.equals(apiCode, CODE_SUCCESS)) {
return GlobalErrorCodeConstants.SUCCESS;
}
if (Objects.equals(apiCode, "FAIL")) {
if (Objects.equals(apiMsg, "AppID不存在,请检查后再试")) {
return PAY_CONFIG_APP_ID_ERROR;
}
if (Objects.equals(apiMsg, "签名错误,请检查后再试")
|| Objects.equals(apiMsg, "签名错误")) {
return PAY_CONFIG_SIGN_ERROR;
}
}
if (Objects.equals(apiCode, "PARAM_ERROR")) {
if (Objects.equals(apiMsg, "无效的openid")) {
return PAY_OPENID_ERROR;
}
}
if (Objects.equals(apiCode, "CustomErrorCode")) {
if (StrUtil.contains(apiMsg, "必填字段")) {
return PAY_PARAM_MISSING;
}
}
return null;
}
}
package pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.io.FileUtils;
import cn.iocoder.yudao.framework.common.util.object.ObjectUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
import cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyResult;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyV3Result;
import com.github.binarywang.wxpay.bean.order.WxPayMpOrderResult;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderRequest;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderV3Request;
import com.github.binarywang.wxpay.bean.result.WxPayUnifiedOrderV3Result;
import com.github.binarywang.wxpay.bean.result.enums.TradeTypeEnum;
import com.github.binarywang.wxpay.config.WxPayConfig;
import com.github.binarywang.wxpay.constant.WxPayConstants;
import com.github.binarywang.wxpay.exception.WxPayException;
import com.github.binarywang.wxpay.service.WxPayService;
import com.github.binarywang.wxpay.service.impl.WxPayServiceImpl;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.CODE_SUCCESS;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.MESSAGE_SUCCESS;
/**
* 微信小程序下支付
*
* @author zwy
*/
@Slf4j
public class WXLitePayClient extends AbstractPayClient<cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig> {
private WxPayService client;
public WXLitePayClient(Long channelId, cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig config) {
super(channelId, PayChannelEnum.WX_LITE.getCode(), config, new WXCodeMapping());
}
@Override
protected void doInit() {
WxPayConfig payConfig = new WxPayConfig();
BeanUtil.copyProperties(config, payConfig, "keyContent");
payConfig.setTradeType(WxPayConstants.TradeType.JSAPI); // 设置使用 JS API 支付方式
// if (StrUtil.isNotEmpty(config.getKeyContent())) {
// payConfig.setKeyContent(config.getKeyContent().getBytes(StandardCharsets.UTF_8));
// }
if (StrUtil.isNotEmpty(config.getPrivateKeyContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateKeyPath(FileUtils.createTempFile(config.getPrivateKeyContent()).getPath());
}
if (StrUtil.isNotEmpty(config.getPrivateCertContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateCertPath(FileUtils.createTempFile(config.getPrivateCertContent()).getPath());
}
// 真实客户端
this.client = new WxPayServiceImpl();
client.setConfig(payConfig);
}
@Override
public PayCommonResult<WxPayMpOrderResult> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
WxPayMpOrderResult response;
try {
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
response = this.unifiedOrderV2(reqDTO);
break;
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V3:
WxPayUnifiedOrderV3Result.JsapiResult responseV3 = this.unifiedOrderV3(reqDTO);
// 将 V3 的结果,统一转换成 V2。返回的字段是一致的
response = new WxPayMpOrderResult();
BeanUtil.copyProperties(responseV3, response, true);
break;
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
} catch (WxPayException e) {
log.error("[unifiedOrder][request({}) 发起支付失败,原因({})]", toJsonString(reqDTO), e);
return PayCommonResult.build(ObjectUtils.defaultIfNull(e.getErrCode(), e.getReturnCode(), "CustomErrorCode"),
ObjectUtils.defaultIfNull(e.getErrCodeDes(), e.getCustomErrorMsg()), null, codeMapping);
}
return PayCommonResult.build(CODE_SUCCESS, MESSAGE_SUCCESS, response, codeMapping);
}
private WxPayMpOrderResult unifiedOrderV2(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderRequest request = WxPayUnifiedOrderRequest.newBuilder()
.outTradeNo(reqDTO.getMerchantOrderId())
.body(reqDTO.getBody())
.totalFee(reqDTO.getAmount().intValue()) // 单位分
.timeExpire(DateUtil.format(reqDTO.getExpireTime(), "yyyyMMddHHmmss")) // v2的时间格式
.spbillCreateIp(reqDTO.getUserIp())
.openid(getOpenid(reqDTO))
.notifyUrl(reqDTO.getNotifyUrl())
.build();
// 执行请求
return client.createOrder(request);
}
private WxPayUnifiedOrderV3Result.JsapiResult unifiedOrderV3(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderV3Request request = new WxPayUnifiedOrderV3Request();
request.setOutTradeNo(reqDTO.getMerchantOrderId());
request.setDescription(reqDTO.getBody());
request.setAmount(new WxPayUnifiedOrderV3Request
.Amount()
.setTotal(reqDTO
.getAmount()
.intValue())); // 单位分
request.setTimeExpire(DateUtil.format(reqDTO.getExpireTime(), "yyyy-MM-dd'T'HH:mm:ssXXX")); // v3的时间格式
request.setPayer(new WxPayUnifiedOrderV3Request.Payer().setOpenid(getOpenid(reqDTO)));
request.setSceneInfo(new WxPayUnifiedOrderV3Request.SceneInfo().setPayerClientIp(reqDTO.getUserIp()));
request.setNotifyUrl(reqDTO.getNotifyUrl());
// 执行请求
return client.createOrderV3(TradeTypeEnum.JSAPI, request);
}
private static String getOpenid(PayOrderUnifiedReqDTO reqDTO) {
String openid = MapUtil.getStr(reqDTO.getChannelExtras(), "openid");
if (StrUtil.isEmpty(openid)) {
throw new IllegalArgumentException("支付请求的 openid 不能为空!");
}
return openid;
}
/**
*
* 微信支付回调 分 v2 和v3 的处理方式
*
* @param data 通知结果
* @return 支付回调对象
* @throws WxPayException 微信异常类
*/
@Override
public PayOrderNotifyRespDTO parseOrderNotify(PayNotifyDataDTO data) throws WxPayException {
log.info("[parseOrderNotify][微信支付回调data数据:{}]", data.getBody());
// 微信支付 v2 回调结果处理
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
return parseOrderNotifyV2(data);
case WXPayClientConfig.API_VERSION_V3:
return parseOrderNotifyV3(data);
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
}
private PayOrderNotifyRespDTO parseOrderNotifyV3(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyV3Result wxPayOrderNotifyV3Result = client.parseOrderNotifyV3Result(data.getBody(), null);
WxPayOrderNotifyV3Result.DecryptNotifyResult result = wxPayOrderNotifyV3Result.getResult();
// 转换结果
Assert.isTrue(Objects.equals(wxPayOrderNotifyV3Result.getResult().getTradeState(), "SUCCESS"),
"支付结果非 SUCCESS");
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
private PayOrderNotifyRespDTO parseOrderNotifyV2(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyResult notifyResult = client.parseOrderNotifyResult(data.getBody());
Assert.isTrue(Objects.equals(notifyResult.getResultCode(), "SUCCESS"), "支付结果非 SUCCESS");
// 转换结果
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
}
@Override
public PayRefundNotifyDTO parseRefundNotify(PayNotifyDataDTO notifyData) {
//TODO 需要实现
throw new UnsupportedOperationException("需要实现");
}
@Override
protected PayCommonResult<PayRefundUnifiedRespDTO> doUnifiedRefund(PayRefundUnifiedReqDTO reqDTO) throws Throwable {
//TODO 需要实现
throw new UnsupportedOperationException();
}
}
package pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.io.FileUtils;
import cn.iocoder.yudao.framework.common.util.object.ObjectUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
import cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyResult;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyV3Result;
import com.github.binarywang.wxpay.bean.order.WxPayNativeOrderResult;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderRequest;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderV3Request;
import com.github.binarywang.wxpay.bean.result.enums.TradeTypeEnum;
import com.github.binarywang.wxpay.config.WxPayConfig;
import com.github.binarywang.wxpay.constant.WxPayConstants;
import com.github.binarywang.wxpay.exception.WxPayException;
import com.github.binarywang.wxpay.service.WxPayService;
import com.github.binarywang.wxpay.service.impl.WxPayServiceImpl;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.CODE_SUCCESS;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.MESSAGE_SUCCESS;
/**
* 微信 App 支付
*
* @author zwy
*/
@Slf4j
public class WXNativePayClient extends AbstractPayClient<cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig> {
private WxPayService client;
public WXNativePayClient(Long channelId, cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig config) {
super(channelId, PayChannelEnum.WX_NATIVE.getCode(), config, new WXCodeMapping());
}
@Override
protected void doInit() {
WxPayConfig payConfig = new WxPayConfig();
BeanUtil.copyProperties(config, payConfig, "keyContent");
payConfig.setTradeType(WxPayConstants.TradeType.NATIVE); // 设置使用 native 支付方式
// if (StrUtil.isNotEmpty(config.getKeyContent())) {
// payConfig.setKeyContent(config.getKeyContent().getBytes(StandardCharsets.UTF_8));
// }
if (StrUtil.isNotEmpty(config.getPrivateKeyContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateKeyPath(FileUtils.createTempFile(config.getPrivateKeyContent()).getPath());
}
if (StrUtil.isNotEmpty(config.getPrivateCertContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateCertPath(FileUtils.createTempFile(config.getPrivateCertContent()).getPath());
}
// 真实客户端
this.client = new WxPayServiceImpl();
client.setConfig(payConfig);
}
@Override
public PayCommonResult<String> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
// 这里原生的返回的是支付的 url 所以直接使用string接收
// "invokeResponse": "weixin://wxpay/bizpayurl?pr=EGYAem7zz"
String responseV3;
try {
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
responseV3 = unifiedOrderV2(reqDTO).getCodeUrl();
break;
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V3:
responseV3 = this.unifiedOrderV3(reqDTO);
break;
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
} catch (WxPayException e) {
log.error("[unifiedOrder][request({}) 发起支付失败,原因({})]", toJsonString(reqDTO), e);
return PayCommonResult.build(ObjectUtils.defaultIfNull(e.getErrCode(), e.getReturnCode(), "CustomErrorCode"),
ObjectUtils.defaultIfNull(e.getErrCodeDes(), e.getCustomErrorMsg()), null, codeMapping);
}
return PayCommonResult.build(CODE_SUCCESS, MESSAGE_SUCCESS, responseV3, codeMapping);
}
private WxPayNativeOrderResult unifiedOrderV2(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
//前端
String tradeType = reqDTO.getChannelExtras().get("trade_type");
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderRequest request = WxPayUnifiedOrderRequest
.newBuilder()
.outTradeNo(reqDTO.getMerchantOrderId())
.body(reqDTO.getBody())
.totalFee(reqDTO.getAmount().intValue()) // 单位分
.timeExpire(DateUtil.format(reqDTO.getExpireTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.spbillCreateIp(reqDTO.getUserIp())
.notifyUrl(reqDTO.getNotifyUrl())
.productId(tradeType)
.build();
// 执行请求
return client.createOrder(request);
}
private String unifiedOrderV3(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderV3Request request = new WxPayUnifiedOrderV3Request();
request.setOutTradeNo(reqDTO.getMerchantOrderId());
request.setDescription(reqDTO.getBody());
request.setAmount(new WxPayUnifiedOrderV3Request.Amount().setTotal(reqDTO.getAmount().intValue())); // 单位分
request.setSceneInfo(new WxPayUnifiedOrderV3Request.SceneInfo().setPayerClientIp(reqDTO.getUserIp()));
request.setNotifyUrl(reqDTO.getNotifyUrl());
// 执行请求
return client.createOrderV3(TradeTypeEnum.NATIVE, request);
}
/**
*
* 微信支付回调 分v2 和v3 的处理方式
*
* @param data 通知结果
* @return 支付回调对象
* @throws WxPayException 微信异常类
*/
@Override
public PayOrderNotifyRespDTO parseOrderNotify(PayNotifyDataDTO data) throws WxPayException {
log.info("微信支付回调data数据:{}", data.getBody());
// 微信支付 v2 回调结果处理
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
return parseOrderNotifyV2(data);
case WXPayClientConfig.API_VERSION_V3:
return parseOrderNotifyV3(data);
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
}
private PayOrderNotifyRespDTO parseOrderNotifyV3(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyV3Result wxPayOrderNotifyV3Result = client.parseOrderNotifyV3Result(data.getBody(), null);
WxPayOrderNotifyV3Result.DecryptNotifyResult result = wxPayOrderNotifyV3Result.getResult();
// 转换结果
Assert.isTrue(Objects.equals(wxPayOrderNotifyV3Result.getResult().getTradeState(), "SUCCESS"),
"支付结果非 SUCCESS");
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
private PayOrderNotifyRespDTO parseOrderNotifyV2(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyResult notifyResult = client.parseOrderNotifyResult(data.getBody());
Assert.isTrue(Objects.equals(notifyResult.getResultCode(), "SUCCESS"), "支付结果非 SUCCESS");
// 转换结果
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
}
@Override
public PayRefundNotifyDTO parseRefundNotify(PayNotifyDataDTO notifyData) {
// TODO 需要实现
throw new UnsupportedOperationException("需要实现");
}
@Override
protected PayCommonResult<PayRefundUnifiedRespDTO> doUnifiedRefund(PayRefundUnifiedReqDTO reqDTO) throws Throwable {
// TODO 需要实现
throw new UnsupportedOperationException();
}
}
package pay.core.client.impl.wx;
import cn.hutool.core.io.IoUtil;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import lombok.Data;
import javax.validation.ConstraintViolation;
import javax.validation.Validator;
import javax.validation.constraints.NotBlank;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.util.Set;
/**
* 微信支付的 PayClientConfig 实现类
* 属性主要来自 {@link com.github.binarywang.wxpay.config.WxPayConfig} 的必要属性
*
* @author 芋道源码
*/
@Data
public class WXPayClientConfig implements PayClientConfig {
/**
* API 版本 - V2
* https://pay.weixin.qq.com/wiki/doc/api/jsapi.php?chapter=4_1
*/
public static final String API_VERSION_V2 = "v2";
/**
* API 版本 - V3
* https://pay.weixin.qq.com/wiki/doc/apiv3/wechatpay/wechatpay-1.shtml
*/
public static final String API_VERSION_V3 = "v3";
/**
* 公众号或者小程序的 appid
*/
@NotBlank(message = "APPID 不能为空", groups = {V2.class, V3.class})
private String appId;
/**
* 商户号
*/
@NotBlank(message = "商户号 不能为空", groups = {V2.class, V3.class})
private String mchId;
/**
* API 版本
*/
@NotBlank(message = "API 版本 不能为空", groups = {V2.class, V3.class})
private String apiVersion;
// ========== V2 版本的参数 ==========
/**
* 商户密钥
*/
@NotBlank(message = "商户密钥 不能为空", groups = V2.class)
private String mchKey;
/**
* apiclient_cert.p12 证书文件的绝对路径或者以 classpath: 开头的类路径.
* 对应的字符串
*
* 注意,可通过 {@link #main(String[])} 读取
*/
/// private String keyContent;
// ========== V3 版本的参数 ==========
/**
* apiclient_key.pem 证书文件的绝对路径或者以 classpath: 开头的类路径.
* 对应的字符串
* 注意,可通过 {@link #main(String[])} 读取
*/
@NotBlank(message = "apiclient_key 不能为空", groups = V3.class)
private String privateKeyContent;
/**
* apiclient_cert.pem 证书文件的绝对路径或者以 classpath: 开头的类路径.
* 对应的字符串
* <p>
* 注意,可通过 {@link #main(String[])} 读取
*/
@NotBlank(message = "apiclient_cert 不能为空", groups = V3.class)
private String privateCertContent;
/**
* apiV3 密钥值
*/
@NotBlank(message = "apiV3 密钥值 不能为空", groups = V3.class)
private String apiV3Key;
/**
* 分组校验 v2版本
*/
public interface V2 {
}
/**
* 分组校验 v3版本
*/
public interface V3 {
}
@Override
public Set<ConstraintViolation<PayClientConfig>> verifyParam(Validator validator) {
return validator.validate(this, this.getApiVersion().equals(API_VERSION_V2) ? V2.class : V3.class);
}
public static void main(String[] args) throws FileNotFoundException {
String path = "/Users/yunai/Downloads/wx_pay/apiclient_cert.p12";
/// String path = "/Users/yunai/Downloads/wx_pay/apiclient_key.pem";
/// String path = "/Users/yunai/Downloads/wx_pay/apiclient_cert.pem";
System.out.println(IoUtil.readUtf8(new FileInputStream(path)));
}
}
package pay.core.client.impl.wx;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.io.FileUtils;
import cn.iocoder.yudao.framework.common.util.object.ObjectUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayCommonResult;
import cn.iocoder.yudao.framework.pay.core.client.dto.*;
import cn.iocoder.yudao.framework.pay.core.client.impl.AbstractPayClient;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.enums.PayChannelEnum;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyResult;
import com.github.binarywang.wxpay.bean.notify.WxPayOrderNotifyV3Result;
import com.github.binarywang.wxpay.bean.order.WxPayMpOrderResult;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderRequest;
import com.github.binarywang.wxpay.bean.request.WxPayUnifiedOrderV3Request;
import com.github.binarywang.wxpay.bean.result.WxPayUnifiedOrderV3Result;
import com.github.binarywang.wxpay.bean.result.enums.TradeTypeEnum;
import com.github.binarywang.wxpay.config.WxPayConfig;
import com.github.binarywang.wxpay.constant.WxPayConstants;
import com.github.binarywang.wxpay.exception.WxPayException;
import com.github.binarywang.wxpay.service.WxPayService;
import com.github.binarywang.wxpay.service.impl.WxPayServiceImpl;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.util.json.JsonUtils.toJsonString;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.CODE_SUCCESS;
import static cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXCodeMapping.MESSAGE_SUCCESS;
/**
* 微信支付(公众号)的 PayClient 实现类
*
* @author 芋道源码
*/
@Slf4j
public class WXPubPayClient extends AbstractPayClient<cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig> {
private WxPayService client;
public WXPubPayClient(Long channelId, cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig config) {
super(channelId, PayChannelEnum.WX_PUB.getCode(), config, new WXCodeMapping());
}
@Override
protected void doInit() {
WxPayConfig payConfig = new WxPayConfig();
BeanUtil.copyProperties(config, payConfig, "keyContent");
payConfig.setTradeType(WxPayConstants.TradeType.JSAPI); // 设置使用 JS API 支付方式
// if (StrUtil.isNotEmpty(config.getKeyContent())) {
// payConfig.setKeyContent(config.getKeyContent().getBytes(StandardCharsets.UTF_8));
// }
if (StrUtil.isNotEmpty(config.getPrivateKeyContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateKeyPath(FileUtils.createTempFile(config.getPrivateKeyContent()).getPath());
}
if (StrUtil.isNotEmpty(config.getPrivateCertContent())) {
// weixin-pay-java 存在 BUG,无法直接设置内容,所以创建临时文件来解决
payConfig.setPrivateCertPath(FileUtils.createTempFile(config.getPrivateCertContent()).getPath());
}
// 真实客户端
this.client = new WxPayServiceImpl();
client.setConfig(payConfig);
}
@Override
public PayCommonResult<WxPayMpOrderResult> doUnifiedOrder(PayOrderUnifiedReqDTO reqDTO) {
WxPayMpOrderResult response;
try {
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
response = this.unifiedOrderV2(reqDTO);
break;
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V3:
WxPayUnifiedOrderV3Result.JsapiResult responseV3 = this.unifiedOrderV3(reqDTO);
// 将 V3 的结果,统一转换成 V2。返回的字段是一致的
response = new WxPayMpOrderResult();
BeanUtil.copyProperties(responseV3, response, true);
break;
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
} catch (WxPayException e) {
log.error("[unifiedOrder][request({}) 发起支付失败,原因({})]", toJsonString(reqDTO), e);
return PayCommonResult.build(ObjectUtils.defaultIfNull(e.getErrCode(), e.getReturnCode(), "CustomErrorCode"),
ObjectUtils.defaultIfNull(e.getErrCodeDes(), e.getCustomErrorMsg()),null, codeMapping);
}
return PayCommonResult.build(CODE_SUCCESS, MESSAGE_SUCCESS, response, codeMapping);
}
private WxPayMpOrderResult unifiedOrderV2(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderRequest request = WxPayUnifiedOrderRequest.newBuilder()
.outTradeNo(reqDTO.getMerchantOrderId())
.body(reqDTO.getBody())
.totalFee(reqDTO.getAmount().intValue()) // 单位分
.timeExpire(DateUtil.format(reqDTO.getExpireTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.spbillCreateIp(reqDTO.getUserIp())
.openid(getOpenid(reqDTO))
.notifyUrl(reqDTO.getNotifyUrl())
.build();
// 执行请求
return client.createOrder(request);
}
private WxPayUnifiedOrderV3Result.JsapiResult unifiedOrderV3(PayOrderUnifiedReqDTO reqDTO) throws WxPayException {
// 构建 WxPayUnifiedOrderRequest 对象
WxPayUnifiedOrderV3Request request = new WxPayUnifiedOrderV3Request();
request.setOutTradeNo(reqDTO.getMerchantOrderId());
request.setDescription(reqDTO.getBody());
request.setAmount(new WxPayUnifiedOrderV3Request.Amount().setTotal(reqDTO.getAmount().intValue())); // 单位分
request.setTimeExpire(DateUtil.format(reqDTO.getExpireTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"));
request.setPayer(new WxPayUnifiedOrderV3Request.Payer().setOpenid(getOpenid(reqDTO)));
request.setSceneInfo(new WxPayUnifiedOrderV3Request.SceneInfo().setPayerClientIp(reqDTO.getUserIp()));
request.setNotifyUrl(reqDTO.getNotifyUrl());
// 执行请求
return client.createOrderV3(TradeTypeEnum.JSAPI, request);
}
private static String getOpenid(PayOrderUnifiedReqDTO reqDTO) {
String openid = MapUtil.getStr(reqDTO.getChannelExtras(), "openid");
if (StrUtil.isEmpty(openid)) {
throw new IllegalArgumentException("支付请求的 openid 不能为空!");
}
return openid;
}
/**
*
* 微信支付回调 分v2 和v3 的处理方式
*
* @param data 通知结果
* @return 支付回调对象
* @throws WxPayException 微信异常类
*/
@Override
public PayOrderNotifyRespDTO parseOrderNotify(PayNotifyDataDTO data) throws WxPayException {
log.info("[parseOrderNotify][微信支付回调data数据: {}]", data.getBody());
// 微信支付 v2 回调结果处理
switch (config.getApiVersion()) {
case cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig.API_VERSION_V2:
return parseOrderNotifyV2(data);
case WXPayClientConfig.API_VERSION_V3:
return parseOrderNotifyV3(data);
default:
throw new IllegalArgumentException(String.format("未知的 API 版本(%s)", config.getApiVersion()));
}
}
private PayOrderNotifyRespDTO parseOrderNotifyV3(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyV3Result wxPayOrderNotifyV3Result = client.parseOrderNotifyV3Result(data.getBody(), null);
WxPayOrderNotifyV3Result.DecryptNotifyResult result = wxPayOrderNotifyV3Result.getResult();
// 转换结果
Assert.isTrue(Objects.equals(wxPayOrderNotifyV3Result.getResult().getTradeState(), "SUCCESS"),
"支付结果非 SUCCESS");
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(result.getOutTradeNo())
.channelOrderNo(result.getTradeState())
.successTime(LocalDateTimeUtil.parse(result.getSuccessTime(), "yyyy-MM-dd'T'HH:mm:ssXXX"))
.data(data.getBody())
.build();
}
private PayOrderNotifyRespDTO parseOrderNotifyV2(PayNotifyDataDTO data) throws WxPayException {
WxPayOrderNotifyResult notifyResult = client.parseOrderNotifyResult(data.getBody());
Assert.isTrue(Objects.equals(notifyResult.getResultCode(), "SUCCESS"), "支付结果非 SUCCESS");
// 转换结果
return PayOrderNotifyRespDTO
.builder()
.orderExtensionNo(notifyResult.getOutTradeNo())
.channelOrderNo(notifyResult.getTransactionId())
.channelUserId(notifyResult.getOpenid())
.successTime(LocalDateTimeUtil.parse(notifyResult.getTimeEnd(), "yyyyMMddHHmmss"))
.data(data.getBody())
.build();
}
@Override
public PayRefundNotifyDTO parseRefundNotify(PayNotifyDataDTO notifyData) {
// TODO 需要实现
throw new UnsupportedOperationException("需要实现");
}
@Override
protected PayCommonResult<PayRefundUnifiedRespDTO> doUnifiedRefund(PayRefundUnifiedReqDTO reqDTO) throws Throwable {
// TODO 需要实现
throw new UnsupportedOperationException();
}
}
package pay.core.enums;
import cn.hutool.core.util.ArrayUtil;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.wx.WXPayClientConfig;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 支付渠道的编码的枚举
* 枚举值
*
* @author 芋道源码
*/
@Getter
@AllArgsConstructor
public enum PayChannelEnum {
WX_PUB("wx_pub", "微信 JSAPI 支付", WXPayClientConfig.class), // 公众号网页
WX_LITE("wx_lite", "微信小程序支付", WXPayClientConfig.class),
WX_APP("wx_app", "微信 App 支付", WXPayClientConfig.class),
WX_NATIVE("wx_native", "微信 native 支付", WXPayClientConfig.class),
ALIPAY_PC("alipay_pc", "支付宝 PC 网站支付", AlipayPayClientConfig.class),
ALIPAY_WAP("alipay_wap", "支付宝 Wap 网站支付", AlipayPayClientConfig.class),
ALIPAY_APP("alipay_app", "支付宝App 支付", AlipayPayClientConfig.class),
ALIPAY_QR("alipay_qr", "支付宝扫码支付", AlipayPayClientConfig.class);
/**
* 编码
* <p>
* 参考 https://www.pingxx.com/api/支付渠道属性值.html
*/
private final String code;
/**
* 名字
*/
private final String name;
/**
* 配置类
*/
private final Class<? extends PayClientConfig> configClass;
/**
* 微信支付
*/
public static final String WECHAT = "WECHAT";
/**
* 支付宝支付
*/
public static final String ALIPAY = "ALIPAY";
public static PayChannelEnum getByCode(String code) {
return ArrayUtil.firstMatch(o -> o.getCode().equals(code), values());
}
}
package pay.core.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 渠道统一的退款返回结果
*
* @author jason
*/
@Getter
@AllArgsConstructor
public enum PayChannelRefundRespEnum {
SUCCESS(1, "退款成功"),
FAILURE(2, "退款失败"),
PROCESSING(3,"退款处理中"),
CLOSED(4, "退款关闭");
private final Integer status;
private final String name;
}
package pay.core.enums;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
/**
* 支付框架的错误码枚举
*
* 短信框架,使用 2-002-000-000 段
*
* @author 芋道源码
*/
public interface PayFrameworkErrorCodeConstants {
ErrorCode PAY_UNKNOWN = new ErrorCode(2002000000, "未知错误,需要解析");
// ========== 配置相关相关 2002000100 ==========
ErrorCode PAY_CONFIG_APP_ID_ERROR = new ErrorCode(2002000100, "支付渠道 AppId 不正确");
ErrorCode PAY_CONFIG_SIGN_ERROR = new ErrorCode(2002000100, "签名错误"); // 例如说,微信支付,配置错了 mchId 或者 mchKey
// ========== 其它相关 2002000900 开头 ==========
ErrorCode PAY_OPENID_ERROR = new ErrorCode(2002000900, "无效的 openid"); // 例如说,微信 openid 未授权过
ErrorCode PAY_PARAM_MISSING = new ErrorCode(2002000901, "请求参数缺失"); // 例如说,支付少传了金额
ErrorCode EXCEPTION = new ErrorCode(2002000999, "调用异常");
}
package pay.core.enums;
/**
* 退款通知, 统一的渠道退款状态
*
* @author jason
*/
public enum PayNotifyRefundStatusEnum {
/**
* 支付宝 中 全额退款 trade_status=TRADE_CLOSED, 部分退款 trade_status=TRADE_SUCCESS
* 退款成功
*/
SUCCESS,
/**
* 支付宝退款通知没有这个状态
* 退款异常
*/
ABNORMAL;
}
......@@ -11,7 +11,6 @@ import cn.iocoder.yudao.framework.sms.core.client.dto.SmsTemplateRespDTO;
import cn.iocoder.yudao.framework.sms.core.enums.SmsTemplateAuditStatusEnum;
import cn.iocoder.yudao.framework.sms.core.property.SmsChannelProperties;
import cn.iocoder.yudao.framework.common.util.collection.MapUtils;
import cn.iocoder.yudao.framework.common.util.date.DateUtils;
import cn.iocoder.yudao.framework.sms.core.enums.SmsFrameworkErrorCodeConstants;
import com.aliyuncs.AcsRequest;
import com.aliyuncs.IAcsClient;
......@@ -27,6 +26,7 @@ import org.mockito.ArgumentMatcher;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import java.time.LocalDateTime;
import java.util.List;
import java.util.function.Function;
......@@ -125,7 +125,7 @@ public class AliyunSmsClientTest extends BaseMockitoUnitTest {
assertEquals("DELIVERED", statuses.get(0).getErrorCode());
assertEquals("用户接收成功", statuses.get(0).getErrorMsg());
assertEquals("13900000001", statuses.get(0).getMobile());
assertEquals(DateUtils.buildTime(2017, 2, 2, 22, 23, 24), statuses.get(0).getReceiveTime());
assertEquals(LocalDateTime.of(2017, 2, 2, 22, 23, 24), statuses.get(0).getReceiveTime());
assertEquals("12345", statuses.get(0).getSerialNo());
assertEquals(67890L, statuses.get(0).getLogId());
}
......@@ -181,7 +181,7 @@ public class AliyunSmsClientTest extends BaseMockitoUnitTest {
when(client.getAcsResponse(any(AcsRequest.class))).thenThrow(ex);
// 调用,并断言异常
SmsCommonResult<?> result = smsClient.invoke(request,null);
SmsCommonResult<?> result = smsClient.invoke(request, null);
// 断言
assertEquals(ex.getErrCode(), result.getApiCode());
assertEquals(ex.getErrMsg(), result.getApiMsg());
......
......@@ -6,7 +6,6 @@ import cn.iocoder.yudao.framework.common.core.KeyValue;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import cn.iocoder.yudao.framework.common.util.collection.ArrayUtils;
import cn.iocoder.yudao.framework.common.util.collection.MapUtils;
import cn.iocoder.yudao.framework.common.util.date.DateUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.sms.core.client.SmsCommonResult;
import cn.iocoder.yudao.framework.sms.core.client.dto.SmsReceiveRespDTO;
......@@ -25,6 +24,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
......@@ -146,7 +146,7 @@ public class TencentSmsClientTest extends BaseMockitoUnitTest {
assertEquals("DELIVRD", statuses.get(0).getErrorCode());
assertEquals("用户短信送达成功", statuses.get(0).getErrorMsg());
assertEquals("13900000001", statuses.get(0).getMobile());
assertEquals(DateUtils.buildTime(2015, 10, 17, 8, 3, 4), statuses.get(0).getReceiveTime());
assertEquals(LocalDateTime.of(2015, 10, 17, 8, 3, 4), statuses.get(0).getReceiveTime());
assertEquals("12345", statuses.get(0).getSerialNo());
assertEquals(67890L, statuses.get(0).getLogId());
}
......
......@@ -10,7 +10,6 @@ import cn.iocoder.yudao.framework.sms.core.client.dto.SmsSendRespDTO;
import cn.iocoder.yudao.framework.sms.core.client.dto.SmsTemplateRespDTO;
import cn.iocoder.yudao.framework.sms.core.enums.SmsTemplateAuditStatusEnum;
import cn.iocoder.yudao.framework.sms.core.property.SmsChannelProperties;
import cn.iocoder.yudao.framework.common.util.date.DateUtils;
import com.google.common.collect.Lists;
import com.yunpian.sdk.YunpianClient;
import com.yunpian.sdk.api.SmsApi;
......@@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
......@@ -115,7 +115,7 @@ public class YunpianSmsClientTest extends BaseMockitoUnitTest {
assertEquals("", statuses.get(0).getErrorCode());
assertNull(statuses.get(0).getErrorMsg());
assertEquals("15205201314", statuses.get(0).getMobile());
assertEquals(DateUtils.buildTime(2014, 3, 17, 22, 55, 21), statuses.get(0).getReceiveTime());
assertEquals(LocalDateTime.of(2014, 3, 17, 22, 55, 21), statuses.get(0).getReceiveTime());
assertEquals("9527", statuses.get(0).getSerialNo());
assertEquals(1024L, statuses.get(0).getLogId());
}
......
......@@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.SortingField;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
......@@ -33,7 +34,7 @@ public class MyBatisUtils {
// 排序字段
if (!CollectionUtil.isEmpty(sortingFields)) {
page.addOrder(sortingFields.stream().map(sortingField -> SortingField.ORDER_ASC.equals(sortingField.getOrder()) ?
OrderItem.asc(sortingField.getField()) : OrderItem.desc(sortingField.getField()))
OrderItem.asc(sortingField.getField()) : OrderItem.desc(sortingField.getField()))
.collect(Collectors.toList()));
}
return page;
......@@ -78,7 +79,10 @@ public class MyBatisUtils {
* @return Column 对象
*/
public static Column buildColumn(String tableName, Alias tableAlias, String column) {
return new Column(tableAlias != null ? tableAlias.getName() + "." + column : column);
if (tableAlias != null) {
tableName = tableAlias.getName();
}
return new Column(tableName + StringPool.DOT + column);
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论