001/*
002 *  Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.*;
029import com.mybatisflex.core.relation.RelationManager;
030import com.mybatisflex.core.table.TableInfo;
031import com.mybatisflex.core.table.TableInfoFactory;
032import org.apache.ibatis.exceptions.TooManyResultsException;
033import org.apache.ibatis.session.defaults.DefaultSqlSession;
034
035import java.util.ArrayList;
036import java.util.Collection;
037import java.util.Collections;
038import java.util.HashMap;
039import java.util.HashSet;
040import java.util.List;
041import java.util.Map;
042import java.util.Set;
043import java.util.function.Consumer;
044
045import static com.mybatisflex.core.query.QueryMethods.count;
046
047public class MapperUtil {
048
049    private MapperUtil() {
050    }
051
052
053    /**
054     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
055     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
056     *
057     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
058     *
059     * <p><pre>
060     * {@code
061     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
062     * }
063     * </pre>
064     *
065     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
066     */
067    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
068        return QueryWrapper.create()
069            .select(count().as("total"))
070            .from(queryWrapper).as("t");
071    }
072    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper,List<QueryColumn> customCountColumns) {
073        return customCountColumns!=null?QueryWrapper.create()
074            .select(customCountColumns)
075            .from(queryWrapper).as("t"):rawCountQueryWrapper(queryWrapper);
076    }
077    /**
078     * 优化 COUNT 查询语句。
079     */
080    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
081        return optimizeCountQueryWrapper(queryWrapper, Collections.singletonList(count().as("total")));
082    }
083    /**
084     * 优化 COUNT 查询语句。
085     */
086    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper, List<QueryColumn> customCountColumns) {
087        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
088        QueryWrapper clone = queryWrapper.clone();
089
090        List<UnionWrapper> unions = CPI.getUnions(clone);
091        if(!CollectionUtil.isEmpty(unions)){
092            List<UnionWrapper> newUnions = new ArrayList<>(unions.size());
093            for (UnionWrapper union : unions) {
094                QueryWrapper unionQuery = optimizeCountQueryWrapper(union.getQueryWrapper().clone(),null);
095                UnionWrapper clone1 = union.clone();
096                clone1.setQueryWrapper(unionQuery);
097                newUnions.add(clone1);
098            }
099            CPI.setUnions(clone, newUnions);
100        }
101
102        // 将最后面的 order by 移除掉
103        CPI.setOrderBys(clone, null);
104        // 获取查询列和分组列,用于判断是否进行优化
105        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
106        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
107        QueryCondition havingCondition = CPI.getHavingQueryCondition(clone);
108        // 如果有 distinct、group by、having 等语句则不优化
109        // 这种一旦优化了就会造成 count 语句查询出来的值不对
110        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns) || havingCondition != null) {
111            return clone;
112        }
113        // 判断能不能清除 join 语句
114        if (canClearJoins(clone)) {
115            CPI.setJoins(clone, null);
116        }
117        // 将 select 里面的列换成 COUNT(*) AS `total`
118        if(customCountColumns!=null){
119            if(hasUnion(clone)){
120                return rawCountQueryWrapper(clone,customCountColumns);
121            }else {
122                CPI.setSelectColumns(clone, customCountColumns);
123            }
124        }
125        return clone;
126    }
127
128    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
129        if (CollectionUtil.isEmpty(selectColumns)) {
130            return false;
131        }
132        for (QueryColumn selectColumn : selectColumns) {
133            if (selectColumn instanceof DistinctQueryColumn) {
134                return true;
135            }
136        }
137        return false;
138    }
139
140    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
141        return CollectionUtil.isNotEmpty(groupByColumns);
142    }
143
144    private static boolean hasUnion(QueryWrapper countQueryWrapper) {
145        return CollectionUtil.isNotEmpty(CPI.getUnions(countQueryWrapper));
146    }
147
148    private static boolean canClearJoins(QueryWrapper queryWrapper) {
149        List<Join> joins = CPI.getJoins(queryWrapper);
150        if (CollectionUtil.isEmpty(joins)) {
151            return false;
152        }
153
154        // 只有全是 left join 语句才会清除 join
155        // 因为如果是 inner join 或 right join 往往都会放大记录数
156        for (Join join : joins) {
157            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
158                return false;
159            }
160        }
161
162        // 获取 join 语句中使用到的表名
163        List<String> joinTables = new ArrayList<>();
164        joins.forEach(join -> {
165            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
166            if (joinQueryTable != null) {
167                String tableName = joinQueryTable.getName();
168                if (StringUtil.isNotBlank(joinQueryTable.getAlias())) {
169                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
170                } else {
171                    joinTables.add(tableName);
172                }
173            }
174        });
175
176        // 获取 where 语句中的条件
177        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
178
179        // 最后判断一下 where 中是否用到了 join 的表
180        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
181    }
182
183    @SafeVarargs
184    public static <T, R> Page<R> doPaginate(
185        BaseMapper<T> mapper,
186        Page<R> page,
187        QueryWrapper queryWrapper,
188        Class<R> asType,
189        boolean withRelations,
190        Consumer<FieldQueryBuilder<R>>... consumers
191    ) {
192        Long limitRows = CPI.getLimitRows(queryWrapper);
193        Long limitOffset = CPI.getLimitOffset(queryWrapper);
194        try {
195            // 只有 totalRow 小于 0 的时候才会去查询总量
196            // 这样方便用户做总数缓存,而非每次都要去查询总量
197            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
198
199            if (page.getTotalRow() < 0) {
200
201                QueryWrapper countQueryWrapper;
202
203                if (page.needOptimizeCountQuery()) {
204                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
205                } else {
206                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
207                }
208
209                // optimize: 在 count 之前先去掉 limit 参数,避免 count 查询错误
210                CPI.setLimitRows(countQueryWrapper, null);
211                CPI.setLimitOffset(countQueryWrapper, null);
212
213                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
214            }
215
216            if (!page.hasRecords()) {
217                if (withRelations) {
218                    RelationManager.clearConfigIfNecessary();
219                }
220                return page;
221            }
222
223            queryWrapper.limit(page.offset(), page.getPageSize());
224
225            List<R> records;
226            if (asType != null) {
227                records = mapper.selectListByQueryAs(queryWrapper, asType);
228            } else {
229                // noinspection unchecked
230                records = (List<R>) mapper.selectListByQuery(queryWrapper);
231            }
232
233            if (withRelations) {
234                queryRelations(mapper, records);
235            }
236
237            queryFields(mapper, records, consumers);
238            page.setRecords(records);
239
240            return page;
241
242        } finally {
243            // 将之前设置的 limit 清除掉
244            // 保险起见把重置代码放到 finally 代码块中
245            CPI.setLimitRows(queryWrapper, limitRows);
246            CPI.setLimitOffset(queryWrapper, limitOffset);
247        }
248    }
249
250
251    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
252        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
253            return;
254        }
255
256        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
257        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
258            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
259            consumer.accept(fieldQueryBuilder);
260
261            FieldQuery fieldQuery = fieldQueryBuilder.build();
262
263            String className = fieldQuery.getEntityClass().getName();
264            String fieldName = fieldQuery.getFieldName();
265            String mapKey = className + '#' + fieldName;
266
267            fieldQueryMap.put(mapKey, fieldQuery);
268        }
269
270        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
271    }
272
273
274    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
275        if (entity != null) {
276            queryRelations(mapper, Collections.singletonList(entity));
277        } else {
278            RelationManager.clearConfigIfNecessary();
279        }
280        return entity;
281    }
282
283    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
284        RelationManager.queryRelations(mapper, entities);
285        return entities;
286    }
287
288
289    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
290        if (ClassUtil.canInstance(type.getModifiers())) {
291            return (Class<? extends Collection>) type;
292        }
293
294        if (List.class.isAssignableFrom(type)) {
295            return ArrayList.class;
296        }
297
298        if (Set.class.isAssignableFrom(type)) {
299            return HashSet.class;
300        }
301
302        throw new IllegalStateException("Field query can not support type: " + type.getName());
303    }
304
305
306    /**
307     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
308     */
309    public static <T> T getSelectOneResult(List<T> list) {
310        if (list == null || list.isEmpty()) {
311            return null;
312        }
313        int size = list.size();
314        if (size == 1) {
315            return list.get(0);
316        }
317        throw new TooManyResultsException(
318            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
319    }
320
321    public static long getLongNumber(List<Object> objects) {
322        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
323        if (object == null) {
324            return 0;
325        } else if (object instanceof Number) {
326            return ((Number) object).longValue();
327        } else {
328            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
329        }
330    }
331
332
333    public static Map<String, Object> preparedParams(BaseMapper<?> baseMapper, Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
334        Map<String, Object> newParams = new HashMap<>();
335
336        if (params != null) {
337            newParams.putAll(params);
338        }
339
340        newParams.put("pageOffset", page.offset());
341        newParams.put("pageNumber", page.getPageNumber());
342        newParams.put("pageSize", page.getPageSize());
343
344        DbType dbType = DialectFactory.getHintDbType();
345        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
346
347        if (queryWrapper != null) {
348            TableInfo tableInfo = TableInfoFactory.ofMapperClass(baseMapper.getClass());
349            tableInfo.appendConditions(null, queryWrapper);
350            preparedQueryWrapper(newParams, queryWrapper);
351        }
352
353        return newParams;
354    }
355
356
357    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
358        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
359        StringBuilder sqlBuilder = new StringBuilder();
360        char quote = 0;
361        int index = 0;
362        for (int i = 0; i < sql.length(); ++i) {
363            char ch = sql.charAt(i);
364            if (ch == '\'') {
365                if (quote == 0) {
366                    quote = ch;
367                } else if (quote == '\'') {
368                    quote = 0;
369                }
370            } else if (ch == '"') {
371                if (quote == 0) {
372                    quote = ch;
373                } else if (quote == '"') {
374                    quote = 0;
375                }
376            }
377            if (quote == 0 && ch == '?') {
378                sqlBuilder.append("#{qwParams_").append(index++).append("}");
379            } else {
380                sqlBuilder.append(ch);
381            }
382        }
383        params.put("qwSql", sqlBuilder.toString());
384        Object[] valueArray = CPI.getValueArray(queryWrapper);
385        for (int i = 0; i < valueArray.length; i++) {
386            params.put("qwParams_" + i, valueArray[i]);
387        }
388    }
389
390}